1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <algorithm>
17#include <cmath>
18
19#include "tensorflow/core/framework/common_shape_fns.h"
20#include "tensorflow/core/framework/kernel_shape_util.h"
21#include "tensorflow/core/framework/numeric_op.h"
22#include "tensorflow/core/framework/op.h"
23#include "tensorflow/core/framework/shape_inference.h"
24#include "tensorflow/core/lib/core/bits.h"
25#include "tensorflow/core/lib/math/math_util.h"
26#include "tensorflow/core/util/mirror_pad_mode.h"
27#include "tensorflow/core/util/padding.h"
28#include "tensorflow/core/util/tensor_format.h"
29
30namespace tensorflow {
31
32using shape_inference::DimensionHandle;
33using shape_inference::InferenceContext;
34using shape_inference::ShapeHandle;
35
36namespace {
37
38Status FractionalPoolShapeFn(InferenceContext* c) {
39 ShapeHandle input;
40 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
41
42 std::vector<float> pooling_ratio;
43 TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio));
44 if (pooling_ratio.size() != 4) {
45 return errors::InvalidArgument(
46 "pooling_ratio field must specify 4 dimensions");
47 }
48 std::vector<DimensionHandle> output_dims;
49 for (int i = 0; i < 4; ++i) {
50 DimensionHandle d = c->Dim(input, i);
51 if (c->ValueKnown(d)) {
52 // This must match the same logic in the kernel function in
53 // core/kernels/fractional_max_pool_op.cc.
54 auto val =
55 static_cast<int64_t>(std::floor(c->Value(d) / pooling_ratio[i]));
56 if (val < 0) {
57 return errors::InvalidArgument("Size computed for dim ", i,
58 " is negative: ", val);
59 }
60 output_dims.push_back(c->MakeDim(val));
61 } else {
62 output_dims.push_back(c->UnknownDim());
63 }
64 }
65
66 c->set_output(0, c->MakeShape(output_dims));
67 c->set_output(1, c->Vector(output_dims[1]));
68 c->set_output(2, c->Vector(output_dims[2]));
69 return OkStatus();
70}
71
72} // namespace
73
74// --------------------------------------------------------------------------
75
76REGISTER_OP("AvgPool")
77 .Input("value: T")
78 .Output("output: T")
79 .Attr("ksize: list(int) >= 4")
80 .Attr("strides: list(int) >= 4")
81 .Attr(GetPaddingAttrString())
82 .Attr(GetConvnetDataFormatAttrString())
83 .Attr("T: {half, bfloat16, float, double}")
84 .SetShapeFn(shape_inference::AvgPoolShape);
85
86REGISTER_OP("AvgPoolGrad")
87 .Input("orig_input_shape: int32")
88 .Input("grad: T")
89 .Output("output: T")
90 .Attr("ksize: list(int) >= 4")
91 .Attr("strides: list(int) >= 4")
92 .Attr(GetPaddingAttrString())
93 .Attr(GetConvnetDataFormatAttrString())
94 .Attr("T: {half, bfloat16, float, double}")
95 .SetShapeFn(shape_inference::AvgPoolGradShape);
96
97// --------------------------------------------------------------------------
98
99REGISTER_OP("BatchNormWithGlobalNormalization")
100 .Input("t: T")
101 .Input("m: T")
102 .Input("v: T")
103 .Input("beta: T")
104 .Input("gamma: T")
105 .Output("result: T")
106 .Attr("T: numbertype")
107 .Attr("variance_epsilon: float")
108 .Attr("scale_after_normalization: bool")
109 .Deprecated(9, "Use tf.nn.batch_normalization()")
110 .SetShapeFn([](InferenceContext* c) {
111 ShapeHandle input;
112 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
113
114 DimensionHandle last_dim = c->Dim(input, 3);
115 for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma
116 ShapeHandle vec;
117 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
118 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
119 }
120
121 ShapeHandle out;
122 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
123 c->set_output(0, out);
124 return OkStatus();
125 });
126
127REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
128 .Input("t: T")
129 .Input("m: T")
130 .Input("v: T")
131 .Input("gamma: T")
132 .Input("backprop: T")
133 .Output("dx: T")
134 .Output("dm: T")
135 .Output("dv: T")
136 .Output("db: T")
137 .Output("dg: T")
138 .Attr("T: numbertype")
139 .Attr("variance_epsilon: float")
140 .Attr("scale_after_normalization: bool")
141 .Deprecated(9, "Use tf.nn.batch_normalization()")
142 .SetShapeFn([](InferenceContext* c) {
143 ShapeHandle input;
144 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
145 TF_RETURN_IF_ERROR(
146 c->Merge(input, c->input(4), &input)); // with backprop
147
148 DimensionHandle last_dim = c->Dim(input, 3);
149 for (int i = 1; i < 4; ++i) { // covers m, v, gamma
150 ShapeHandle vec;
151 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
152 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
153 }
154
155 ShapeHandle dx;
156 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx));
157 c->set_output(0, dx);
158
159 ShapeHandle vector_shape = c->Vector(last_dim);
160 c->set_output(1, vector_shape);
161 c->set_output(2, vector_shape);
162 c->set_output(3, vector_shape);
163 c->set_output(4, vector_shape);
164 return OkStatus();
165 });
166
167// --------------------------------------------------------------------------
168
169REGISTER_OP("FusedBatchNorm")
170 .Input("x: T")
171 .Input("scale: T")
172 .Input("offset: T")
173 .Input("mean: T")
174 .Input("variance: T")
175 .Output("y: T")
176 .Output("batch_mean: T")
177 .Output("batch_variance: T")
178 .Output("reserve_space_1: T")
179 .Output("reserve_space_2: T")
180 .Attr("T: {float}")
181 .Attr("epsilon: float = 0.0001")
182 .Attr("exponential_avg_factor: float = 1.0")
183 .Attr(GetConvnetDataFormatAttrString())
184 .Attr("is_training: bool = true")
185 .SetShapeFn(shape_inference::FusedBatchNormShape);
186
187REGISTER_OP("FusedBatchNormV2")
188 .Input("x: T")
189 .Input("scale: U")
190 .Input("offset: U")
191 .Input("mean: U")
192 .Input("variance: U")
193 .Output("y: T")
194 .Output("batch_mean: U")
195 .Output("batch_variance: U")
196 .Output("reserve_space_1: U")
197 .Output("reserve_space_2: U")
198 .Attr("T: {half, bfloat16, float}")
199 .Attr("U: {float}")
200 .Attr("epsilon: float = 0.0001")
201 .Attr("exponential_avg_factor: float = 1.0")
202 .Attr(GetConvnetDataFormatAttrString())
203 .Attr("is_training: bool = true")
204 .SetShapeFn(shape_inference::FusedBatchNormShape);
205
206REGISTER_OP("FusedBatchNormV3")
207 .Input("x: T")
208 .Input("scale: U")
209 .Input("offset: U")
210 .Input("mean: U")
211 .Input("variance: U")
212 .Output("y: T")
213 .Output("batch_mean: U")
214 .Output("batch_variance: U")
215 .Output("reserve_space_1: U")
216 .Output("reserve_space_2: U")
217 .Output("reserve_space_3: U")
218 .Attr("T: {half, bfloat16, float}")
219 .Attr("U: {bfloat16, float}")
220 .Attr("epsilon: float = 0.0001")
221 .Attr("exponential_avg_factor: float = 1.0")
222 .Attr(GetConvnetDataFormat2D3DAttrString())
223 .Attr("is_training: bool = true")
224 .SetShapeFn(shape_inference::FusedBatchNormV3Shape);
225
226REGISTER_OP("_FusedBatchNormEx")
227 .Input("x: T")
228 .Input("scale: U")
229 .Input("offset: U")
230 .Input("mean: U")
231 .Input("variance: U")
232 .Input("side_input: num_side_inputs * T")
233 .Output("y: T")
234 .Output("batch_mean: U")
235 .Output("batch_variance: U")
236 .Output("reserve_space_1: U")
237 .Output("reserve_space_2: U")
238 .Output("reserve_space_3: U")
239 .Attr("T: {half, float, bfloat16}")
240 .Attr("U: {float}")
241 .Attr("epsilon: float = 0.0001")
242 .Attr("exponential_avg_factor: float = 1.0")
243 .Attr("num_side_inputs: int >= 0 = 0")
244 .Attr("activation_mode: string = \"Identity\"")
245 .Attr(GetConvnetDataFormatAttrString())
246 .Attr("is_training: bool = true")
247 .SetShapeFn(shape_inference::FusedBatchNormExShape)
248 .Doc(R"doc(
249Internal FusedBatchNorm operation: reserved for internal use.
250
251Do not invoke this operator directly in Python. A fusion optimization is
252expected to create these operators.
253)doc");
254
255REGISTER_OP("FusedBatchNormGrad")
256 .Input("y_backprop: T")
257 .Input("x: T")
258 .Input("scale: T")
259 .Input("reserve_space_1: T")
260 .Input("reserve_space_2: T")
261 .Output("x_backprop: T")
262 .Output("scale_backprop: T")
263 .Output("offset_backprop: T")
264 .Output("reserve_space_3: T")
265 .Output("reserve_space_4: T")
266 .Attr("T: {float}")
267 .Attr("epsilon: float = 0.0001")
268 .Attr(GetConvnetDataFormatAttrString())
269 .Attr("is_training: bool = true")
270 .SetShapeFn(shape_inference::FusedBatchNormGradShape);
271
272REGISTER_OP("FusedBatchNormGradV2")
273 .Input("y_backprop: T")
274 .Input("x: T")
275 .Input("scale: float")
276 .Input("reserve_space_1: U")
277 .Input("reserve_space_2: U")
278 .Output("x_backprop: T")
279 .Output("scale_backprop: U")
280 .Output("offset_backprop: U")
281 .Output("reserve_space_3: U")
282 .Output("reserve_space_4: U")
283 .Attr("T: {half, bfloat16, float}")
284 .Attr("U: {float}")
285 .Attr("epsilon: float = 0.0001")
286 .Attr(GetConvnetDataFormatAttrString())
287 .Attr("is_training: bool = true")
288 .SetShapeFn(shape_inference::FusedBatchNormGradShape);
289
290REGISTER_OP("FusedBatchNormGradV3")
291 .Input("y_backprop: T")
292 .Input("x: T")
293 .Input("scale: float")
294 .Input("reserve_space_1: U")
295 .Input("reserve_space_2: U")
296 .Input("reserve_space_3: U")
297 .Output("x_backprop: T")
298 .Output("scale_backprop: U")
299 .Output("offset_backprop: U")
300 .Output("reserve_space_4: U")
301 .Output("reserve_space_5: U")
302 .Attr("T: {half, bfloat16, float}")
303 .Attr("U: {float}")
304 .Attr("epsilon: float = 0.0001")
305 .Attr(GetConvnetDataFormat2D3DAttrString())
306 .Attr("is_training: bool = true")
307 .SetShapeFn(shape_inference::FusedBatchNormGradShape);
308
309REGISTER_OP("_FusedBatchNormGradEx")
310 .Input("y_backprop: T")
311 .Input("x: T")
312 .Input("scale: float")
313 .Input("reserve_space_1: U")
314 .Input("reserve_space_2: U")
315 .Input("reserve_space_3: U")
316 .Input("offset: float")
317 .Input("y: T")
318 .Output("x_backprop: T")
319 .Output("scale_backprop: U")
320 .Output("offset_backprop: U")
321 .Output("reserve_space_4: U")
322 .Output("reserve_space_5: U")
323 .Output("side_input_backprop: num_side_inputs * T")
324 .Attr("T: {half, float}")
325 .Attr("U: {float}")
326 .Attr("epsilon: float = 0.0001")
327 .Attr("num_side_inputs: int >= 0 = 0")
328 .Attr("activation_mode: string = \"Identity\"")
329 .Attr(GetConvnetDataFormat2D3DAttrString())
330 .Attr("is_training: bool = true")
331 .SetShapeFn(shape_inference::FusedBatchNormGradExShape)
332 .Doc(R"doc(
333Internal FusedBatchNormGrad operation: reserved for internal use.
334
335Do not invoke this operator directly in Python. A fusion optimization is
336expected to create these operators.
337)doc");
338// --------------------------------------------------------------------------
339
340REGISTER_OP("BiasAdd")
341 .Attr("T: numbertype")
342 .Input("value: T")
343 .Input("bias: T")
344 .Attr(GetConvnetDataFormatAttrString())
345 .Output("output: T")
346 .SetShapeFn(shape_inference::BiasAddShape);
347// --------------------------------------------------------------------------
348
349REGISTER_OP("BiasAddGrad")
350 .Attr("T: numbertype")
351 .Input("out_backprop: T")
352 .Attr(GetConvnetDataFormatAttrString())
353 .Output("output: T")
354 .SetShapeFn(shape_inference::BiasAddGradShape);
355// --------------------------------------------------------------------------
356
357REGISTER_OP("BiasAddV1")
358 .Attr("T: numbertype")
359 .Input("value: T")
360 .Input("bias: T")
361 .Output("output: T")
362 .SetShapeFn(shape_inference::BiasAddShape);
363// --------------------------------------------------------------------------
364
365REGISTER_OP("Conv2D")
366 .Input("input: T")
367 .Input("filter: T")
368 .Output("output: T")
369 .Attr("T: {half, bfloat16, float, double, int32}")
370 .Attr("strides: list(int)")
371 .Attr("use_cudnn_on_gpu: bool = true")
372 .Attr(GetPaddingAttrStringWithExplicit())
373 .Attr(GetExplicitPaddingsAttrString())
374 .Attr(GetConvnetDataFormatAttrString())
375 .Attr("dilations: list(int) = [1, 1, 1, 1]")
376 .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding);
377
378REGISTER_OP("Conv2DBackpropInput")
379 .Input("input_sizes: int32")
380 .Input("filter: T")
381 .Input("out_backprop: T")
382 .Output("output: T")
383 .Attr("T: {half, bfloat16, float, double, int32}")
384 .Attr("strides: list(int)")
385 .Attr("use_cudnn_on_gpu: bool = true")
386 .Attr(GetPaddingAttrStringWithExplicit())
387 .Attr(GetExplicitPaddingsAttrString())
388 .Attr(GetConvnetDataFormatAttrString())
389 .Attr("dilations: list(int) = [1, 1, 1, 1]")
390 .SetShapeFn(shape_inference::Conv2DBackpropInputShape);
391
392// TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a
393// more general string attribute ('kernel_impl'?) that can be used to
394// select among several possible implementations.
395REGISTER_OP("Conv2DBackpropFilter")
396 .Input("input: T")
397 .Input("filter_sizes: int32")
398 .Input("out_backprop: T")
399 .Output("output: T")
400 .Attr("T: {half, bfloat16, float, double}")
401 .Attr("strides: list(int)")
402 .Attr("use_cudnn_on_gpu: bool = true")
403 .Attr(GetPaddingAttrStringWithExplicit())
404 .Attr(GetExplicitPaddingsAttrString())
405 .Attr(GetConvnetDataFormatAttrString())
406 .Attr("dilations: list(int) = [1, 1, 1, 1]")
407 .SetShapeFn([](InferenceContext* c) {
408 ShapeHandle s;
409 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
410 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
411 c->set_output(0, s);
412 return OkStatus();
413 });
414
415REGISTER_OP("_FusedConv2D")
416 .Input("input: T")
417 .Input("filter: T")
418 .Input("args: TArgs")
419 .Input("host_args : num_host_args * float")
420 .Output("output: T")
421 .Attr("T: {half, float, double, int8, qint8}")
422 .Attr("TArgs: list(type)")
423 .Attr("num_args: int >= 0")
424 .Attr("num_host_args: int >= 0 =0")
425 .Attr("strides: list(int)")
426 .Attr(GetPaddingAttrStringWithExplicit())
427 .Attr(GetExplicitPaddingsAttrString())
428 .Attr("data_format: { 'NHWC', 'NCHW', 'NCHW_VECT_C' } = 'NHWC'")
429 .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
430 .Attr("dilations: list(int) = [1, 1, 1, 1]")
431 .Attr("use_cudnn_on_gpu: bool = true")
432 .Attr("fused_ops: list(string) = []")
433 // Attributes for the FusedBatchNorm ------------------------------------ //
434 .Attr("epsilon: float = 0.0001")
435 // Attributes for the LeakyRelu ----------------------------------------- //
436 .Attr("leakyrelu_alpha: float = 0.2")
437 // ---------------------------------------------------------------------- //
438 .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
439 .Doc(R"doc(
440Performs a convolution followed by a specified series of operations.
441
442The inputs to the convolution are `input` and `filter`. The series of operations
443that follows is specified by the `fused_ops` attribute, which is a list of TF op
444names specified as strings (e.g. "Relu"). They are performed in order, where the
445(first) input to each op is the output of the preceding op. The first input and
446the output of each fused_op must be of type T.
447
448Currently supported fused_op combinations are: [X] and [X,A], where X is one of
449{"BiasAdd","FusedBatchNorm"} and A is one of {"Elu","Relu","Relu6"}.
450
451* The first input to op X is the Conv2D result, and the additional input(s) to X
452are specified by `args`.
453* If there is an op A specified, the output of op X is the input to op A, and op
454A produces the _FusedConv2D output. Otherwise, op X produces the _FusedConv2D
455output.
456
457*NOTE*: Do not invoke this operator directly in Python. Grappler is expected to
458create these operators.
459)doc");
460
461namespace {
462
463Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
464 ShapeHandle input;
465 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
466
467 ShapeHandle resized = input;
468 int paddings_index = 1;
469 int filter_index = 2;
470 if (has_resize) {
471 paddings_index = 2;
472 filter_index = 3;
473
474 ShapeHandle unused_size;
475 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size));
476
477 const Tensor* size = c->input_tensor(1);
478 DimensionHandle new_height = c->UnknownDim();
479 DimensionHandle new_width = c->UnknownDim();
480 if (size != nullptr) {
481 new_height = c->MakeDim(size->flat<int32>()(0));
482 new_width = c->MakeDim(size->flat<int32>()(1));
483 }
484 TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized));
485 TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized));
486 }
487
488 ShapeHandle paddings;
489 TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings));
490 TF_RETURN_IF_ERROR(
491 c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized));
492 TF_RETURN_IF_ERROR(
493 c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings));
494
495 const Tensor* paddings_t = c->input_tensor(paddings_index);
496 ShapeHandle padded;
497 if (paddings_t != nullptr) {
498 std::vector<DimensionHandle> output_dims;
499 for (int i = 0; i < 4; ++i) {
500 DimensionHandle dim = c->Dim(resized, i);
501 int64_t p0 = static_cast<int64_t>(paddings_t->matrix<int32>()(i, 0));
502 int64_t p1 = static_cast<int64_t>(paddings_t->matrix<int32>()(i, 1));
503 if (p0 < 0 || p1 < 0) {
504 return errors::InvalidArgument("Paddings must be non-negative");
505 }
506
507 TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim));
508 output_dims.push_back(dim);
509 }
510 padded = c->MakeShape(output_dims);
511 } else {
512 padded = c->UnknownShapeOfRank(4);
513 }
514
515 // Work out the convolution's effect with 'padded' as the input.
516 ShapeHandle filter;
517 TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter));
518 std::vector<int32> strides;
519 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
520 if (strides.size() != 4) {
521 return errors::InvalidArgument(
522 "Operation requires the stride attribute to contain 4 values, but ",
523 "got: ", strides.size());
524 }
525
526 int32_t stride_rows = strides[1];
527 int32_t stride_cols = strides[2];
528
529 DimensionHandle batch_size_dim = c->Dim(padded, 0);
530 DimensionHandle in_rows_dim = c->Dim(padded, 1);
531 DimensionHandle in_cols_dim = c->Dim(padded, 2);
532 DimensionHandle filter_rows_dim = c->Dim(filter, 0);
533 DimensionHandle filter_cols_dim = c->Dim(filter, 1);
534 DimensionHandle output_depth_dim = c->Dim(filter, 3);
535
536 DimensionHandle unused;
537 TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused));
538
539 Padding padding;
540 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
541
542 DimensionHandle output_rows, output_cols;
543 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
544 c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
545 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
546 c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
547
548 ShapeHandle output_shape = c->MakeShape(
549 {batch_size_dim, output_rows, output_cols, output_depth_dim});
550 c->set_output(0, output_shape);
551 return OkStatus();
552}
553
554} // namespace
555
556REGISTER_OP("DataFormatDimMap")
557 .Input("x: T")
558 .Output("y: T")
559 .Attr("T: {int32, int64} = DT_INT32")
560 .Attr("src_format: string = 'NHWC'")
561 .Attr("dst_format: string = 'NCHW'")
562 .SetShapeFn(shape_inference::UnchangedShape);
563
564REGISTER_OP("DataFormatVecPermute")
565 .Input("x: T")
566 .Output("y: T")
567 .Attr("T: {int32, int64} = DT_INT32")
568 .Attr("src_format: string = 'NHWC'")
569 .Attr("dst_format: string = 'NCHW'")
570 .SetShapeFn(shape_inference::UnchangedShape);
571
572REGISTER_OP("FusedResizeAndPadConv2D")
573 .Input("input: T")
574 .Input("size: int32")
575 .Input("paddings: int32")
576 .Input("filter: T")
577 .Output("output: T")
578 .Attr("T: {half, float, double}")
579 .Attr("resize_align_corners: bool = false")
580 .Attr(GetMirrorPadModeAttrString())
581 .Attr("strides: list(int)")
582 .Attr(GetPaddingAttrString())
583 .SetShapeFn([](InferenceContext* c) {
584 return CommonFusedConvCalculations(c, true /* has_resize */);
585 });
586
587REGISTER_OP("FusedPadConv2D")
588 .Input("input: T")
589 .Input("paddings: int32")
590 .Input("filter: T")
591 .Output("output: T")
592 .Attr("T: {half, float, double}")
593 .Attr(GetMirrorPadModeAttrString())
594 .Attr("strides: list(int)")
595 .Attr(GetPaddingAttrString())
596 .SetShapeFn([](InferenceContext* c) {
597 return CommonFusedConvCalculations(c, false /* has_resize */);
598 });
599
600// --------------------------------------------------------------------------
601
602REGISTER_OP("DepthwiseConv2dNative")
603 .Input("input: T")
604 .Input("filter: T")
605 .Output("output: T")
606 .Attr("T: {half, bfloat16, float, double}")
607 .Attr("strides: list(int)")
608 .Attr(GetPaddingAttrStringWithExplicit())
609 .Attr(GetExplicitPaddingsAttrString())
610 .Attr(GetConvnetDataFormatAttrString())
611 .Attr("dilations: list(int) = [1, 1, 1, 1]")
612 .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding);
613
614REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
615 .Input("input_sizes: int32")
616 .Input("filter: T")
617 .Input("out_backprop: T")
618 .Output("output: T")
619 .Attr("T: {half, bfloat16, float, double}")
620 .Attr("strides: list(int)")
621 .Attr(GetPaddingAttrStringWithExplicit())
622 .Attr(GetExplicitPaddingsAttrString())
623 .Attr(GetConvnetDataFormatAttrString())
624 .Attr("dilations: list(int) = [1, 1, 1, 1]")
625 .SetShapeFn([](InferenceContext* c) {
626 ShapeHandle s;
627 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
628 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
629 c->set_output(0, s);
630 return OkStatus();
631 });
632
633REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
634 .Input("input: T")
635 .Input("filter_sizes: int32")
636 .Input("out_backprop: T")
637 .Output("output: T")
638 .Attr("T: {half, bfloat16, float, double}")
639 .Attr("strides: list(int)")
640 .Attr(GetPaddingAttrStringWithExplicit())
641 .Attr(GetExplicitPaddingsAttrString())
642 .Attr(GetConvnetDataFormatAttrString())
643 .Attr("dilations: list(int) = [1, 1, 1, 1]")
644 .SetShapeFn([](InferenceContext* c) {
645 ShapeHandle s;
646 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
647 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
648 c->set_output(0, s);
649 return OkStatus();
650 });
651
652REGISTER_OP("_FusedDepthwiseConv2dNative")
653 .Input("input: T")
654 .Input("filter: T")
655 .Input("args: num_args * T")
656 .Output("output: T")
657 .Attr("T: {half, bfloat16, float, double}")
658 .Attr("num_args: int >= 0")
659 .Attr("strides: list(int)")
660 .Attr(GetPaddingAttrString())
661 .Attr(GetConvnetDataFormatAttrString())
662 .Attr("dilations: list(int) = [1, 1, 1, 1]")
663 .Attr("fused_ops: list(string) = []")
664 // Attributes for the FusedBatchNorm ------------------------------------ //
665 .Attr("epsilon: float = 0.0001")
666 // Attributes for the LeakyRelu ----------------------------------------- //
667 .Attr("leakyrelu_alpha: float = 0.2")
668 // ---------------------------------------------------------------------- //
669 .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
670
671// --------------------------------------------------------------------------
672
673REGISTER_OP("Conv3D")
674 .Input("input: T")
675 .Input("filter: T")
676 .Output("output: T")
677 .Attr("T: {half, bfloat16, float, double}")
678 .Attr("strides: list(int) >= 5")
679 .Attr(GetPaddingAttrString())
680 .Attr(GetConvnet3dDataFormatAttrString())
681 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
682 .SetShapeFn(shape_inference::Conv3DShape);
683
684REGISTER_OP("Conv3DBackpropInput")
685 .Input("input: T")
686 .Input("filter: T")
687 .Input("out_backprop: T")
688 .Output("output: T")
689 .Attr("T: {half, float, double}")
690 .Attr("strides: list(int) >= 5")
691 .Attr(GetPaddingAttrString())
692 .Deprecated(10, "Use Conv3DBackpropInputV2")
693 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
694 .SetShapeFn([](InferenceContext* c) {
695 return UnchangedShapeWithRank(c, 5);
696 });
697
698REGISTER_OP("Conv3DBackpropFilter")
699 .Input("input: T")
700 .Input("filter: T")
701 .Input("out_backprop: T")
702 .Output("output: T")
703 .Attr("T: {half, float, double}")
704 .Attr("strides: list(int) >= 5")
705 .Attr(GetPaddingAttrString())
706 .Deprecated(10, "Use Conv3DBackpropFilterV2")
707 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
708 .SetShapeFn([](InferenceContext* c) {
709 ShapeHandle out;
710 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out));
711 c->set_output(0, out);
712 return OkStatus();
713 });
714
715REGISTER_OP("Conv3DBackpropInputV2")
716 .Input("input_sizes: Tshape")
717 .Input("filter: T")
718 .Input("out_backprop: T")
719 .Output("output: T")
720 .Attr("T: {half, bfloat16, float, double}")
721 .Attr("strides: list(int) >= 5")
722 .Attr(GetPaddingAttrString())
723 .Attr(GetConvnet3dDataFormatAttrString())
724 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
725 .Attr("Tshape: {int32, int64} = DT_INT32")
726 .SetShapeFn([](InferenceContext* c) {
727 ShapeHandle s;
728 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
729 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
730 c->set_output(0, s);
731 return OkStatus();
732 });
733
734REGISTER_OP("Conv3DBackpropFilterV2")
735 .Input("input: T")
736 .Input("filter_sizes: int32")
737 .Input("out_backprop: T")
738 .Output("output: T")
739 .Attr("T: {half, bfloat16, float, double}")
740 .Attr("strides: list(int) >= 5")
741 .Attr(GetPaddingAttrString())
742 .Attr(GetConvnet3dDataFormatAttrString())
743 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
744 .SetShapeFn([](InferenceContext* c) {
745 ShapeHandle s;
746 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
747 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
748 c->set_output(0, s);
749 return OkStatus();
750 });
751
752// --------------------------------------------------------------------------
753
754REGISTER_OP("AvgPool3D")
755 .Input("input: T")
756 .Output("output: T")
757 .Attr("ksize: list(int) >= 5")
758 .Attr("strides: list(int) >= 5")
759 .Attr(GetPaddingAttrString())
760 .Attr(GetConvnet3dDataFormatAttrString())
761 .Attr("T: {half, bfloat16, float, double}")
762 .SetShapeFn(shape_inference::Pool3DShape);
763
764REGISTER_OP("AvgPool3DGrad")
765 .Input("orig_input_shape: int32")
766 .Input("grad: T")
767 .Output("output: T")
768 .Attr("ksize: list(int) >= 5")
769 .Attr("strides: list(int) >= 5")
770 .Attr(GetPaddingAttrString())
771 .Attr(GetConvnet3dDataFormatAttrString())
772 .Attr("T: {half, bfloat16, float, double}")
773 .SetShapeFn(shape_inference::AvgPool3DGradShape);
774
775// --------------------------------------------------------------------------
776
777REGISTER_OP("MaxPool3D")
778 .Input("input: T")
779 .Output("output: T")
780 .Attr("ksize: list(int) >= 5")
781 .Attr("strides: list(int) >= 5")
782 .Attr(GetPaddingAttrString())
783 .Attr(GetConvnet3dDataFormatAttrString())
784 .Attr("T: {half, bfloat16, float}")
785 .SetShapeFn(shape_inference::Pool3DShape);
786
787REGISTER_OP("MaxPool3DGrad")
788 .Input("orig_input: TInput")
789 .Input("orig_output: TInput")
790 .Input("grad: T")
791 .Output("output: T")
792 .Attr("ksize: list(int) >= 5")
793 .Attr("strides: list(int) >= 5")
794 .Attr(GetPaddingAttrString())
795 .Attr(GetConvnet3dDataFormatAttrString())
796 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
797 .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
798 .SetShapeFn(shape_inference::MaxPool3DGradShape);
799
800REGISTER_OP("MaxPool3DGradGrad")
801 .Input("orig_input: T")
802 .Input("orig_output: T")
803 .Input("grad: T")
804 .Output("output: T")
805 .Attr("ksize: list(int) >= 5 ")
806 .Attr("strides: list(int) >= 5")
807 .Attr(GetPaddingAttrString())
808 .Attr(GetConvnet3dDataFormatAttrString())
809 .Attr("T: realnumbertype")
810 .SetShapeFn([](InferenceContext* c) {
811 TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c));
812 ShapeHandle unused;
813 // Validate 'orig_input' is the same shape as 'grad'
814 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
815 // Validate 'orig_output' is same shape as 'output'
816 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
817 return OkStatus();
818 });
819
820// --------------------------------------------------------------------------
821
822REGISTER_OP("L2Loss")
823 .Input("t: T")
824 .Output("output: T")
825 .Attr("T: {half, bfloat16, float, double}")
826 .SetShapeFn(shape_inference::ScalarShape);
827
828// --------------------------------------------------------------------------
829
830REGISTER_OP("LRN")
831 .Input("input: T")
832 .Output("output: T")
833 .Attr("depth_radius: int = 5")
834 .Attr("bias: float = 1.0")
835 .Attr("alpha: float = 1.0")
836 .Attr("beta: float = 0.5")
837 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
838 .SetShapeFn([](InferenceContext* c) {
839 return UnchangedShapeWithRank(c, 4);
840 });
841
842REGISTER_OP("LRNGrad")
843 .Input("input_grads: T")
844 .Input("input_image: T")
845 .Input("output_image: T")
846 .Output("output: T")
847 .Attr("depth_radius: int = 5")
848 .Attr("bias: float = 1.0")
849 .Attr("alpha: float = 1.0")
850 .Attr("beta: float = 0.5")
851 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
852 .SetShapeFn([](InferenceContext* c) {
853 ShapeHandle s;
854 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
855 TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
856 TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
857 c->set_output(0, s);
858 return OkStatus();
859 });
860
861// --------------------------------------------------------------------------
862
863REGISTER_OP("MaxPool")
864 .Attr(
865 "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
866 "uint16, qint8} = DT_FLOAT")
867 .Attr("ksize: list(int) >= 4")
868 .Attr("strides: list(int) >= 4")
869 .Attr(GetPaddingAttrStringWithExplicit())
870 .Attr(GetExplicitPaddingsAttrString())
871 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
872 .Input("input: T")
873 .Output("output: T")
874 .SetShapeFn(shape_inference::MaxPoolShapeWithExplicitPadding);
875
876REGISTER_OP("MaxPoolV2")
877 .Attr(
878 "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
879 "uint16, qint8} = DT_FLOAT")
880 .Attr(GetPaddingAttrString())
881 .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
882 .Input("input: T")
883 .Input("ksize: int32")
884 .Input("strides: int32")
885 .Output("output: T")
886 .SetShapeFn([](InferenceContext* c) {
887 TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3));
888 return OkStatus();
889 });
890
891REGISTER_OP("MaxPoolGrad")
892 .Attr("ksize: list(int) >= 4")
893 .Attr("strides: list(int) >= 4")
894 .Attr(GetPaddingAttrStringWithExplicit())
895 .Attr(GetExplicitPaddingsAttrString())
896 .Attr(GetConvnetDataFormatAttrString())
897 .Input("orig_input: T")
898 .Input("orig_output: T")
899 .Input("grad: T")
900 .Output("output: T")
901 .Attr("T: realnumbertype = DT_FLOAT")
902 .SetShapeFn(shape_inference::MaxPoolGradShape);
903
904REGISTER_OP("MaxPoolGradV2")
905 .Attr(GetPaddingAttrString())
906 .Attr(GetConvnetDataFormatAttrString())
907 .Input("orig_input: T")
908 .Input("orig_output: T")
909 .Input("grad: T")
910 .Input("ksize: int32")
911 .Input("strides: int32")
912 .Output("output: T")
913 .Attr("T: realnumbertype = DT_FLOAT")
914 .SetShapeFn(shape_inference::MaxPoolGradShape);
915
916// TODO(b/150813181): Implement explicit padding.
917REGISTER_OP("MaxPoolGradGrad")
918 .Attr("ksize: list(int) >= 4")
919 .Attr("strides: list(int) >= 4")
920 .Attr(GetPaddingAttrString())
921 .Attr(GetConvnetDataFormatAttrString())
922 .Input("orig_input: T")
923 .Input("orig_output: T")
924 .Input("grad: T")
925 .Output("output: T")
926 .Attr("T: realnumbertype")
927 .SetShapeFn([](InferenceContext* c) {
928 TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
929 ShapeHandle unused;
930 // Validate 'orig_input' is the same shape as 'grad'
931 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
932 // Validate 'orig_output' is same shape as 'output'
933 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
934 return OkStatus();
935 });
936
937REGISTER_OP("MaxPoolGradGradV2")
938 .Attr(GetPaddingAttrString())
939 .Attr(GetConvnetDataFormatAttrString())
940 .Input("orig_input: T")
941 .Input("orig_output: T")
942 .Input("grad: T")
943 .Input("ksize: int32")
944 .Input("strides: int32")
945 .Output("output: T")
946 .Attr("T: realnumbertype")
947 .SetShapeFn([](InferenceContext* c) {
948 TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5));
949 ShapeHandle unused;
950 // Validate 'orig_input' is the same shape as 'grad'
951 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
952 // Validate 'orig_output' is same shape as 'output'
953 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
954 return OkStatus();
955 });
956
957REGISTER_OP("MaxPoolWithArgmax")
958 .Attr("ksize: list(int) >= 4")
959 .Attr("strides: list(int) >= 4")
960 .Attr("Targmax: {int32, int64} = DT_INT64")
961 .Attr(GetPaddingAttrString())
962 .Attr("include_batch_in_index: bool = false")
963 .Input("input: T")
964 .Output("output: T")
965 .Output("argmax: Targmax")
966 .Attr("T: realnumbertype")
967 .SetShapeFn([](InferenceContext* c) {
968 TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
969 c->set_output(1, c->output(0));
970 return OkStatus();
971 });
972
973REGISTER_OP("MaxPoolGradWithArgmax")
974 .Attr("ksize: list(int) >= 4")
975 .Attr("strides: list(int) >= 4")
976 .Attr(GetPaddingAttrString())
977 .Attr("include_batch_in_index: bool = false")
978 .Attr("Targmax: {int32, int64}")
979 .Input("input: T")
980 .Input("grad: T")
981 .Input("argmax: Targmax")
982 .Output("output: T")
983 .Attr("T: realnumbertype")
984 .SetShapeFn([](InferenceContext* c) {
985 return UnchangedShapeWithRank(c, 4);
986 });
987
988REGISTER_OP("MaxPoolGradGradWithArgmax")
989 .Attr("ksize: list(int) >= 4")
990 .Attr("strides: list(int) >= 4")
991 .Attr(GetPaddingAttrString())
992 .Attr("include_batch_in_index: bool = false")
993 .Attr("Targmax: {int32, int64}")
994 .Input("input: T")
995 .Input("grad: T")
996 .Input("argmax: Targmax")
997 .Output("output: T")
998 .Attr("T: realnumbertype")
999 .SetShapeFn([](InferenceContext* c) {
1000 TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
1001 ShapeHandle unused;
1002 // Validate 'orig_input' is the same shape as 'grad'
1003 TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused));
1004 // Validate 'argmax' is same shape as 'output'
1005 TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused));
1006 return OkStatus();
1007 });
1008
1009// --------------------------------------------------------------------------
1010
1011REGISTER_OP("Dilation2D")
1012 .Input("input: T")
1013 .Input("filter: T")
1014 .Output("output: T")
1015 .Attr("T: realnumbertype")
1016 .Attr("strides: list(int) >= 4")
1017 .Attr("rates: list(int) >= 4")
1018 .Attr(GetPaddingAttrString())
1019 .SetShapeFn([](InferenceContext* c) {
1020 ShapeHandle input_shape;
1021 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1022 ShapeHandle filter_shape;
1023 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &filter_shape));
1024
1025 std::vector<int32> strides;
1026 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
1027 if (strides.size() != 4) {
1028 return errors::InvalidArgument(
1029 "Dilation2D requires the stride attribute to contain 4 values, but "
1030 "got: ",
1031 strides.size());
1032 }
1033
1034 std::vector<int32> rates;
1035 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
1036 if (rates.size() != 4) {
1037 return errors::InvalidArgument(
1038 "Dilation2D requires the rates attribute to contain 4 values, but "
1039 "got: ",
1040 rates.size());
1041 }
1042
1043 int32_t stride_rows = strides[1];
1044 int32_t stride_cols = strides[2];
1045
1046 int32_t rate_rows = rates[1];
1047 int32_t rate_cols = rates[2];
1048
1049 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
1050 DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
1051 DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
1052 DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
1053 DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
1054 DimensionHandle output_depth_dim = c->Dim(filter_shape, 2);
1055
1056 if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim) ||
1057 !c->ValueKnown(filter_rows_dim) || !c->ValueKnown(filter_cols_dim)) {
1058 ShapeHandle output_shape =
1059 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
1060 InferenceContext::kUnknownDim, output_depth_dim});
1061 c->set_output(0, output_shape);
1062 return OkStatus();
1063 }
1064 DimensionHandle unused;
1065 TF_RETURN_IF_ERROR(
1066 c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused));
1067
1068 auto in_rows = c->Value(in_rows_dim);
1069 auto in_cols = c->Value(in_cols_dim);
1070 auto filter_rows = c->Value(filter_rows_dim);
1071 auto filter_cols = c->Value(filter_cols_dim);
1072 auto filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_rows - 1);
1073 auto filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_cols - 1);
1074
1075 Padding padding;
1076 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
1077
1078 int64_t output_rows, output_cols;
1079 int64_t padding_before, padding_after;
1080 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
1081 in_rows, filter_rows_eff, stride_rows, padding, &output_rows,
1082 &padding_before, &padding_after));
1083 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
1084 in_cols, filter_cols_eff, stride_cols, padding, &output_cols,
1085 &padding_before, &padding_after));
1086
1087 ShapeHandle output_shape = c->MakeShape(
1088 {batch_size_dim, output_rows, output_cols, output_depth_dim});
1089 c->set_output(0, output_shape);
1090 return OkStatus();
1091 });
1092
1093REGISTER_OP("Dilation2DBackpropInput")
1094 .Input("input: T")
1095 .Input("filter: T")
1096 .Input("out_backprop: T")
1097 .Output("in_backprop: T")
1098 .Attr("T: realnumbertype")
1099 .Attr("strides: list(int) >= 4")
1100 .Attr("rates: list(int) >= 4")
1101 .Attr(GetPaddingAttrString())
1102 .SetShapeFn(shape_inference::UnchangedShape);
1103
1104REGISTER_OP("Dilation2DBackpropFilter")
1105 .Input("input: T")
1106 .Input("filter: T")
1107 .Input("out_backprop: T")
1108 .Output("filter_backprop: T")
1109 .Attr("T: realnumbertype")
1110 .Attr("strides: list(int) >= 4")
1111 .Attr("rates: list(int) >= 4")
1112 .Attr(GetPaddingAttrString())
1113 .SetShapeFn([](InferenceContext* c) {
1114 c->set_output(0, c->input(1));
1115 return OkStatus();
1116 });
1117
1118// --------------------------------------------------------------------------
1119
1120REGISTER_OP("Relu")
1121 .Input("features: T")
1122 .Output("activations: T")
1123 .Attr("T: {realnumbertype, qint8}")
1124 .SetShapeFn(shape_inference::UnchangedShape);
1125
1126REGISTER_OP("ReluGrad")
1127 .Input("gradients: T")
1128 .Input("features: T")
1129 .Output("backprops: T")
1130 .Attr("T: realnumbertype")
1131 .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1132
1133REGISTER_OP("Relu6")
1134 .Input("features: T")
1135 .Output("activations: T")
1136 .Attr("T: realnumbertype")
1137 .SetShapeFn(shape_inference::UnchangedShape);
1138
1139REGISTER_OP("Relu6Grad")
1140 .Input("gradients: T")
1141 .Input("features: T")
1142 .Output("backprops: T")
1143 .Attr("T: realnumbertype")
1144 .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1145
1146REGISTER_OP("LeakyRelu")
1147 .Input("features: T")
1148 .Output("activations: T")
1149 .Attr("alpha: float = 0.2")
1150 .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
1151 .SetShapeFn(shape_inference::UnchangedShape);
1152
1153REGISTER_OP("LeakyReluGrad")
1154 .Input("gradients: T")
1155 .Input("features: T")
1156 .Output("backprops: T")
1157 .Attr("alpha: float = 0.2")
1158 .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
1159 .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1160
1161REGISTER_OP("Elu")
1162 .Input("features: T")
1163 .Output("activations: T")
1164 .Attr("T: {half, bfloat16, float, double}")
1165 .SetShapeFn(shape_inference::UnchangedShape);
1166
1167REGISTER_OP("EluGrad")
1168 .Input("gradients: T")
1169 .Input("outputs: T")
1170 .Output("backprops: T")
1171 .Attr("T: {half, bfloat16, float, double}")
1172 .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1173
1174REGISTER_OP("Selu")
1175 .Input("features: T")
1176 .Output("activations: T")
1177 .Attr("T: {half, bfloat16, float, double}")
1178 .SetShapeFn(shape_inference::UnchangedShape);
1179
1180REGISTER_OP("SeluGrad")
1181 .Input("gradients: T")
1182 .Input("outputs: T")
1183 .Output("backprops: T")
1184 .Attr("T: {half, bfloat16, float, double}")
1185 .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1186
1187REGISTER_OP("Softplus")
1188 .Input("features: T")
1189 .Output("activations: T")
1190 .Attr("T: {half, bfloat16, float, double}")
1191 .SetShapeFn(shape_inference::UnchangedShape);
1192
1193REGISTER_OP("SoftplusGrad")
1194 .Input("gradients: T")
1195 .Input("features: T")
1196 .Output("backprops: T")
1197 .Attr("T: {half, bfloat16, float, double}")
1198 .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1199
1200REGISTER_OP("Softsign")
1201 .Input("features: T")
1202 .Output("activations: T")
1203 .Attr("T: {half, bfloat16, float, double}")
1204 .SetShapeFn(shape_inference::UnchangedShape);
1205
1206REGISTER_OP("SoftsignGrad")
1207 .Input("gradients: T")
1208 .Input("features: T")
1209 .Output("backprops: T")
1210 .Attr("T: {half, bfloat16, float, double}")
1211 .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1212
1213// --------------------------------------------------------------------------
1214
1215REGISTER_OP("Softmax")
1216 .Input("logits: T")
1217 .Output("softmax: T")
1218 .Attr("T: {half, bfloat16, float, double}")
1219 .SetShapeFn([](InferenceContext* c) {
1220 return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
1221 });
1222
1223// --------------------------------------------------------------------------
1224
1225REGISTER_OP("LogSoftmax")
1226 .Input("logits: T")
1227 .Output("logsoftmax: T")
1228 .Attr("T: {half, bfloat16, float, double}")
1229 .SetShapeFn([](InferenceContext* c) {
1230 return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
1231 });
1232
1233// --------------------------------------------------------------------------
1234
1235REGISTER_OP("SoftmaxCrossEntropyWithLogits")
1236 .Input("features: T")
1237 .Input("labels: T")
1238 .Output("loss: T")
1239 .Output("backprop: T")
1240 .Attr("T: {half, bfloat16, float, double}")
1241 .SetShapeFn([](InferenceContext* c) {
1242 ShapeHandle input;
1243 if (c->WithRank(c->input(0), 2, &input) == OkStatus() &&
1244 c->Merge(input, c->input(1), &input) == OkStatus()) {
1245 DimensionHandle batch_size = c->Dim(input, 0);
1246 c->set_output(0, c->Vector(batch_size));
1247 c->set_output(1, input);
1248 return OkStatus();
1249 }
1250 TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFn(c, 1));
1251
1252 if (!c->RankKnown(c->output(1))) {
1253 return errors::InvalidArgument(
1254 "Shape must be broadcasted with rank 2, but is rank is unknown.");
1255 }
1256
1257 if (c->Rank(c->output(1)) != 2) {
1258 return errors::InvalidArgument(
1259 "Shape must be broadcasted with rank 2, but is rank ",
1260 c->Rank(c->output(1)));
1261 }
1262 DimensionHandle batch_size = c->Dim(c->output(1), 0);
1263 c->set_output(0, c->Vector(batch_size));
1264 return OkStatus();
1265 });
1266
1267REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
1268 .Input("features: T")
1269 .Input("labels: Tlabels")
1270 .Output("loss: T")
1271 .Output("backprop: T")
1272 .Attr("T: {half, bfloat16, float, double}")
1273 .Attr("Tlabels: {int32, int64} = DT_INT64")
1274 .SetShapeFn([](InferenceContext* c) {
1275 ShapeHandle features;
1276 ShapeHandle labels;
1277 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
1278 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));
1279
1280 DimensionHandle batch_size;
1281 TF_RETURN_IF_ERROR(
1282 c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
1283 TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));
1284
1285 c->set_output(0, c->Vector(batch_size));
1286 c->set_output(1, features);
1287 return OkStatus();
1288 });
1289
1290// --------------------------------------------------------------------------
1291
1292REGISTER_OP("InTopK")
1293 .Input("predictions: float")
1294 .Input("targets: T")
1295 .Output("precision: bool")
1296 .Attr("k: int")
1297 .Attr("T: {int32, int64} = DT_INT32")
1298 .SetShapeFn([](InferenceContext* c) {
1299 ShapeHandle predictions;
1300 ShapeHandle targets;
1301 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1302 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1303 DimensionHandle batch_size;
1304 TF_RETURN_IF_ERROR(
1305 c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1306 c->set_output(0, c->Vector(batch_size));
1307 return OkStatus();
1308 });
1309
1310// This is the same as `InTopK`, but takes `k` as in input rather than an attr.
1311REGISTER_OP("InTopKV2")
1312 .Input("predictions: float")
1313 .Input("targets: T")
1314 .Input("k: T")
1315 .Output("precision: bool")
1316 .Attr("T: {int32, int64} = DT_INT32")
1317 .SetShapeFn([](InferenceContext* c) {
1318 ShapeHandle predictions;
1319 ShapeHandle targets;
1320 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1321 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1322 DimensionHandle batch_size;
1323 TF_RETURN_IF_ERROR(
1324 c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1325 c->set_output(0, c->Vector(batch_size));
1326 return OkStatus();
1327 });
1328
1329namespace {
1330
1331Status TopKShapeFn(InferenceContext* c) {
1332 ShapeHandle input;
1333 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1334
1335 // Get the k value, either from input tensor or attribute.
1336 DimensionHandle k_dim;
1337 if (c->num_inputs() >= 2) {
1338 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &k_dim));
1339 } else {
1340 int32_t k;
1341 TF_RETURN_IF_ERROR(c->GetAttr("k", &k));
1342 if (k < 0) {
1343 return errors::InvalidArgument("Need k >= 0, got ", k);
1344 }
1345 k_dim = c->MakeDim(k);
1346 }
1347
1348 DimensionHandle last_dim = c->Dim(input, -1);
1349 if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) &&
1350 c->Value(last_dim) < c->Value(k_dim)) {
1351 return errors::InvalidArgument(
1352 "input must have last dimension >= k = ", c->Value(k_dim), " but is ",
1353 c->Value(last_dim));
1354 }
1355
1356 // Replace last_dim with k_dim.
1357 ShapeHandle s;
1358 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1359 TF_RETURN_IF_ERROR(c->Concatenate(s, c->Vector(k_dim), &s));
1360 c->set_output(0, s);
1361 c->set_output(1, s);
1362 return OkStatus();
1363}
1364
1365// Utility functions for ApproxTopKShape.
1366// It is not easy to link xla/client/lib into the tensorflow core lib, so we
1367// have to replicate the logic.
1368// LINT.IfChange
1369inline uint32_t log2_floor(uint64_t value) {
1370 return value == 0 ? 0 : Log2Floor(value);
1371}
1372
1373inline uint32_t log2_ceil(uint64_t value) {
1374 return value == 0 ? 0 : Log2Ceiling(value);
1375}
1376
1377Status ApproxTopKShape(shape_inference::InferenceContext* c) {
1378 int64_t k;
1379 int64_t reduction_dimension;
1380 float recall_target;
1381 int64_t reduction_input_size_override;
1382 bool aggregate_to_topk;
1383 TF_RETURN_IF_ERROR(c->GetAttr("k", &k));
1384 TF_RETURN_IF_ERROR(c->GetAttr("reduction_dimension", &reduction_dimension));
1385 TF_RETURN_IF_ERROR(c->GetAttr("recall_target", &recall_target));
1386 TF_RETURN_IF_ERROR(c->GetAttr("reduction_input_size_override",
1387 &reduction_input_size_override));
1388 TF_RETURN_IF_ERROR(c->GetAttr("aggregate_to_topk", &aggregate_to_topk));
1389 ShapeHandle input_shape;
1390 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape));
1391 if (reduction_dimension < 0) {
1392 // Reverse index
1393 reduction_dimension += c->Rank(input_shape);
1394 }
1395 int64_t reduction_dim_value =
1396 c->Value(c->Dim(input_shape, reduction_dimension));
1397
1398 if (reduction_dim_value < k) {
1399 return errors::InvalidArgument("input must have last dimension >= k = ", k,
1400 " but was ", reduction_dim_value);
1401 }
1402
1403 int64_t output_dim_value = [&] {
1404 if (aggregate_to_topk) {
1405 return k;
1406 }
1407 int64_t tpu_tiling = c->Rank(input_shape) == 1 ? 1024 : 128;
1408 if (reduction_dim_value <= tpu_tiling || recall_target == 1.0) {
1409 return reduction_dim_value;
1410 }
1411 if (k == 1) {
1412 return tpu_tiling;
1413 }
1414 uint64_t logical_input_size = reduction_input_size_override >= 0
1415 ? reduction_input_size_override
1416 : reduction_dim_value;
1417 uint64_t m = std::min<uint64_t>(
1418 std::max<uint64_t>(
1419 static_cast<uint64_t>((1.0 - k) /
1420 std::log(static_cast<double>(recall_target))),
1421 tpu_tiling),
1422 reduction_dim_value);
1423 uint32_t log2_reduction = log2_floor(logical_input_size / m);
1424 if (log2_reduction == 0) {
1425 return reduction_dim_value;
1426 }
1427 log2_reduction = std::min<uint32_t>(
1428 log2_reduction, log2_ceil(reduction_dim_value / tpu_tiling));
1429 return tensorflow::MathUtil::CeilOfRatio<int64_t>(
1430 tensorflow::MathUtil::CeilOfRatio<int64_t>(reduction_dim_value,
1431 tpu_tiling),
1432 (1 << log2_reduction)) *
1433 tpu_tiling;
1434 }();
1435
1436 auto output_dim = c->MakeDim(output_dim_value);
1437
1438 ShapeHandle output_shape;
1439 TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, reduction_dimension, output_dim,
1440 &output_shape));
1441 c->set_output(0, output_shape);
1442 c->set_output(1, output_shape);
1443 return OkStatus();
1444}
1445// LINT.ThenChange(//tensorflow/compiler/xla/client/lib/approx_topk_shape.cc)
1446
1447} // namespace
1448
1449REGISTER_OP("TopK")
1450 .Input("input: T")
1451 .Output("values: T")
1452 .Output("indices: int32")
1453 .Attr("k: int >= 0")
1454 .Attr("sorted: bool = true")
1455 .Attr("T: realnumbertype")
1456 .Deprecated(7, "Use TopKV2 instead")
1457 .SetShapeFn(TopKShapeFn);
1458
1459// This is the same as `TopK`, but takes `k` as in input rather than an attr.
1460REGISTER_OP("TopKV2")
1461 .Input("input: T")
1462 .Input("k: int32")
1463 .Output("values: T")
1464 .Output("indices: int32")
1465 .Attr("sorted: bool = true")
1466 .Attr("T: realnumbertype")
1467 .SetShapeFn(TopKShapeFn);
1468
1469REGISTER_OP("ApproxTopK")
1470 .Input("input: T")
1471 .Output("values: T")
1472 .Output("indices: int32")
1473 .Attr("k: int >= 0")
1474 .Attr("reduction_dimension: int = -1")
1475 .Attr("recall_target: float = 0.95")
1476 .Attr("is_max_k: bool = true")
1477 .Attr("reduction_input_size_override: int = -1")
1478 .Attr("aggregate_to_topk: bool = true")
1479 .Attr("T: {half, bfloat16, float}")
1480 .SetShapeFn(ApproxTopKShape);
1481
1482// --------------------------------------------------------------------------
1483
1484REGISTER_OP("NthElement")
1485 .Input("input: T")
1486 .Input("n: int32")
1487 .Output("values: T")
1488 .Attr("reverse: bool = false")
1489 .Attr("T: realnumbertype")
1490 .SetShapeFn([](InferenceContext* c) {
1491 ShapeHandle input;
1492 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1493
1494 // Get the n value from input tensor, and make sure which is a scalar.
1495 DimensionHandle n_dim;
1496 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &n_dim));
1497
1498 // The last dimension of input tensor must be greater than N.
1499 DimensionHandle last_dim = c->Dim(input, -1);
1500 if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) &&
1501 c->Value(last_dim) <= c->Value(n_dim)) {
1502 return errors::InvalidArgument(
1503 "Input must have last dimension > n = ", c->Value(n_dim),
1504 " but is ", c->Value(last_dim));
1505 }
1506
1507 // Reduce last_dim for output tensor
1508 ShapeHandle s;
1509 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1510 c->set_output(0, s);
1511 return OkStatus();
1512 });
1513
1514// --------------------------------------------------------------------------
1515
1516REGISTER_OP("FractionalMaxPool")
1517 .Input("value: T")
1518 .Output("output: T")
1519 .Output("row_pooling_sequence: int64")
1520 .Output("col_pooling_sequence: int64")
1521 .Attr("pooling_ratio: list(float) >=4")
1522 .Attr("pseudo_random: bool = false")
1523 .Attr("overlapping: bool = false")
1524 .Attr("deterministic: bool = false")
1525 .Attr("seed: int = 0")
1526 .Attr("seed2: int = 0")
1527 .Attr("T: {float, double, int32, int64}")
1528 .SetShapeFn(FractionalPoolShapeFn);
1529
1530REGISTER_OP("FractionalMaxPoolGrad")
1531 .Input("orig_input: T")
1532 .Input("orig_output: T")
1533 .Input("out_backprop: T")
1534 .Input("row_pooling_sequence: int64")
1535 .Input("col_pooling_sequence: int64")
1536 .Output("output: T")
1537 .Attr("overlapping: bool = false")
1538 .Attr("T: {float, double, int32, int64}")
1539 .SetShapeFn([](InferenceContext* c) {
1540 return shape_inference::UnchangedShapeWithRank(c, 4);
1541 });
1542
1543// --------------------------------------------------------------------------
1544
1545REGISTER_OP("FractionalAvgPool")
1546 .Input("value: T")
1547 .Output("output: T")
1548 .Output("row_pooling_sequence: int64")
1549 .Output("col_pooling_sequence: int64")
1550 .Attr("pooling_ratio: list(float) >=4")
1551 .Attr("pseudo_random: bool = false")
1552 .Attr("overlapping: bool = false")
1553 .Attr("deterministic: bool = false")
1554 .Attr("seed: int = 0")
1555 .Attr("seed2: int = 0")
1556 .Attr("T: {float, double, int32, int64}")
1557 .SetShapeFn(FractionalPoolShapeFn);
1558
1559REGISTER_OP("FractionalAvgPoolGrad")
1560 .Input("orig_input_tensor_shape: int64")
1561 .Input("out_backprop: T")
1562 .Input("row_pooling_sequence: int64")
1563 .Input("col_pooling_sequence: int64")
1564 .Output("output: T")
1565 .Attr("overlapping: bool = false")
1566 .Attr("T: {float, double, int32, int64}")
1567 .SetShapeFn([](InferenceContext* c) {
1568 if (c->input_tensor(0) != nullptr) {
1569 ShapeHandle out;
1570 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1571 c->set_output(0, out);
1572 } else {
1573 c->set_output(0, c->UnknownShapeOfRank(4));
1574 }
1575 return OkStatus();
1576 });
1577
1578REGISTER_OP("QuantizedAvgPool")
1579 .Input("input: T")
1580 .Input("min_input: float")
1581 .Input("max_input: float")
1582 .Output("output: T")
1583 .Output("min_output: float")
1584 .Output("max_output: float")
1585 .Attr("T: quantizedtype")
1586 .Attr("ksize: list(int)")
1587 .Attr("strides: list(int)")
1588 .Attr(GetPaddingAttrString())
1589 .SetShapeFn(shape_inference::QuantizedAvgPoolShape);
1590
1591REGISTER_OP("QuantizedBiasAdd")
1592 .Input("input: T1")
1593 .Input("bias: T2")
1594 .Input("min_input: float")
1595 .Input("max_input: float")
1596 .Input("min_bias: float")
1597 .Input("max_bias: float")
1598 .Output("output: out_type")
1599 .Output("min_out: float")
1600 .Output("max_out: float")
1601 .Attr("T1: quantizedtype")
1602 .Attr("T2: quantizedtype")
1603 .Attr("out_type: quantizedtype")
1604 .SetShapeFn([](InferenceContext* c) {
1605 TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c));
1606 ShapeHandle unused;
1607 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1608 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1609 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1610 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1611 c->set_output(1, c->Scalar());
1612 c->set_output(2, c->Scalar());
1613 return OkStatus();
1614 });
1615
1616REGISTER_OP("QuantizedConv2D")
1617 .Input("input: Tinput")
1618 .Input("filter: Tfilter")
1619 .Input("min_input: float")
1620 .Input("max_input: float")
1621 .Input("min_filter: float")
1622 .Input("max_filter: float")
1623 .Output("output: out_type")
1624 .Output("min_output: float")
1625 .Output("max_output: float")
1626 .Attr("Tinput: quantizedtype")
1627 .Attr("Tfilter: quantizedtype")
1628 .Attr("out_type: quantizedtype = DT_QINT32")
1629 .Attr("strides: list(int)")
1630 .Attr(GetPaddingAttrString())
1631 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1632 .SetShapeFn(shape_inference::QuantizedConv2DShape);
1633
1634REGISTER_OP("QuantizedMaxPool")
1635 .Input("input: T")
1636 .Input("min_input: float")
1637 .Input("max_input: float")
1638 .Output("output: T")
1639 .Output("min_output: float")
1640 .Output("max_output: float")
1641 .Attr("T: quantizedtype")
1642 .Attr("ksize: list(int)")
1643 .Attr("strides: list(int)")
1644 .Attr(GetPaddingAttrString())
1645 .SetShapeFn([](InferenceContext* c) {
1646 TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
1647 ShapeHandle unused;
1648 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1649 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1650 c->set_output(1, c->Scalar());
1651 c->set_output(2, c->Scalar());
1652 return OkStatus();
1653 });
1654
1655REGISTER_OP("QuantizedRelu")
1656 .Input("features: Tinput")
1657 .Input("min_features: float")
1658 .Input("max_features: float")
1659 .Output("activations: out_type")
1660 .Output("min_activations: float")
1661 .Output("max_activations: float")
1662 .Attr("Tinput: quantizedtype")
1663 .Attr("out_type: quantizedtype = DT_QUINT8")
1664 .SetShapeFn([](InferenceContext* c) {
1665 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1666 ShapeHandle unused;
1667 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1668 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1669 c->set_output(1, c->Scalar());
1670 c->set_output(2, c->Scalar());
1671 return OkStatus();
1672 });
1673
1674REGISTER_OP("QuantizedRelu6")
1675 .Input("features: Tinput")
1676 .Input("min_features: float")
1677 .Input("max_features: float")
1678 .Output("activations: out_type")
1679 .Output("min_activations: float")
1680 .Output("max_activations: float")
1681 .Attr("Tinput: quantizedtype")
1682 .Attr("out_type: quantizedtype = DT_QUINT8")
1683 .SetShapeFn([](InferenceContext* c) {
1684 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1685 ShapeHandle unused;
1686 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1687 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1688 c->set_output(1, c->Scalar());
1689 c->set_output(2, c->Scalar());
1690 return OkStatus();
1691 });
1692
1693REGISTER_OP("QuantizedReluX")
1694 .Input("features: Tinput")
1695 .Input("max_value: float")
1696 .Input("min_features: float")
1697 .Input("max_features: float")
1698 .Output("activations: out_type")
1699 .Output("min_activations: float")
1700 .Output("max_activations: float")
1701 .Attr("Tinput: quantizedtype")
1702 .Attr("out_type: quantizedtype = DT_QUINT8")
1703 .SetShapeFn([](InferenceContext* c) {
1704 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1705 ShapeHandle unused;
1706 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1707 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1708 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1709 c->set_output(1, c->Scalar());
1710 c->set_output(2, c->Scalar());
1711 return OkStatus();
1712 });
1713
1714REGISTER_OP("QuantizedBatchNormWithGlobalNormalization")
1715 .Input("t: Tinput")
1716 .Input("t_min: float")
1717 .Input("t_max: float")
1718 .Input("m: Tinput")
1719 .Input("m_min: float")
1720 .Input("m_max: float")
1721 .Input("v: Tinput")
1722 .Input("v_min: float")
1723 .Input("v_max: float")
1724 .Input("beta: Tinput")
1725 .Input("beta_min: float")
1726 .Input("beta_max: float")
1727 .Input("gamma: Tinput")
1728 .Input("gamma_min: float")
1729 .Input("gamma_max: float")
1730 .Output("result: out_type")
1731 .Output("result_min: float")
1732 .Output("result_max: float")
1733 .Attr("Tinput: quantizedtype")
1734 .Attr("out_type: quantizedtype")
1735 .Attr("variance_epsilon: float")
1736 .Attr("scale_after_normalization: bool")
1737 .SetShapeFn([](InferenceContext* c) {
1738 ShapeHandle input;
1739 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1740
1741 DimensionHandle last_dim = c->Dim(input, 3);
1742 for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma
1743 ShapeHandle vec;
1744 TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec));
1745 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
1746 }
1747
1748 ShapeHandle out;
1749 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
1750 c->set_output(0, out);
1751 c->set_output(1, c->Scalar());
1752 c->set_output(2, c->Scalar());
1753
1754 return OkStatus();
1755 });
1756
1757#ifdef INTEL_MKL
1758REGISTER_OP("_MklDepthwiseConv2dNative")
1759 .Input("input: T")
1760 .Input("filter: T")
1761 .Input("mkl_input: uint8")
1762 .Input("mkl_filter: uint8")
1763 .Output("output: T")
1764 .Output("filter_output: T")
1765 .Output("mkl_output: uint8")
1766 .Output("mkl_filter_output: uint8")
1767 .Attr("T: {half, bfloat16, float, double}")
1768 .Attr("strides: list(int)")
1769 .Attr("is_filter_const: bool = false")
1770 .Attr(GetPaddingAttrStringWithExplicit())
1771 .Attr(GetConvnetDataFormatAttrString())
1772 .Attr(GetExplicitPaddingsAttrString())
1773 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1774 .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding);
1775
1776REGISTER_OP("_MklConv2D")
1777 .Input("input: T")
1778 .Input("filter: T")
1779 .Input("mkl_input: uint8")
1780 .Input("mkl_filter: uint8")
1781 .Output("output: T")
1782 .Output("filter_output: T")
1783 .Output("mkl_output: uint8")
1784 .Output("mkl_filter_output: uint8")
1785 .Attr("T: {bfloat16, float}")
1786 .Attr("strides: list(int)")
1787 .Attr("use_cudnn_on_gpu: bool = true")
1788 .Attr("is_filter_const: bool = false")
1789 .Attr(GetPaddingAttrStringWithExplicit())
1790 .Attr(GetConvnetDataFormatAttrString())
1791 .Attr(GetExplicitPaddingsAttrString())
1792 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1793 .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1794 .Doc(R"doc(
1795MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution.
1796
1797NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1798expected to invoke these operators.
1799)doc");
1800
1801REGISTER_OP("_MklNativeConv2D")
1802 .Input("input: T")
1803 .Input("filter: T")
1804 .Output("output: T")
1805 .Attr("T: {bfloat16, float}")
1806 .Attr("strides: list(int)")
1807 .Attr("use_cudnn_on_gpu: bool = true")
1808 .Attr("is_filter_const: bool = false")
1809 .Attr(GetPaddingAttrStringWithExplicit())
1810 .Attr(GetExplicitPaddingsAttrString())
1811 .Attr(GetConvnetDataFormatAttrString())
1812 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1813 .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1814 .Doc(R"doc(
1815 MKL version of Conv2D operator for Eager mode. Uses MKL DNN APIs to perform 2D convolution.
1816
1817 NOTE Do not invoke this operator directly in Python. Eager Op rewrite is
1818 expected to invoke these operators.
1819 )doc");
1820
1821REGISTER_OP("__MklDummyConv2DWithBias")
1822 .Input("input: T")
1823 .Input("filter: T")
1824 .Input("bias: T")
1825 .Output("output: T")
1826 .Attr("T: {bfloat16, float}")
1827 .Attr("strides: list(int)")
1828 .Attr("use_cudnn_on_gpu: bool = true")
1829 .Attr("is_filter_const: bool = false")
1830 .Attr(GetPaddingAttrStringWithExplicit())
1831 .Attr(GetExplicitPaddingsAttrString())
1832 .Attr(GetConvnetDataFormatAttrString())
1833 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1834 .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1835 .Doc(R"doc(
1836Dummy node that enables fusing Conv2D and BiasAdd operator for MKL. This node
1837does not perform anything. It is just created as an intermediate output of
1838merging Conv2D and BiasAdd.
1839
1840NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1841expected to invoke these operators.
1842)doc");
1843
1844REGISTER_OP("_MklConv2DWithBias")
1845 .Input("input: T")
1846 .Input("filter: T")
1847 .Input("bias: T")
1848 .Input("mkl_input: uint8")
1849 .Input("mkl_filter: uint8")
1850 .Input("mkl_bias: uint8")
1851 .Output("output: T")
1852 .Output("filter_output: T")
1853 .Output("mkl_output: uint8")
1854 .Output("mkl_filter_output: uint8")
1855 .Attr("T: {bfloat16, float}")
1856 .Attr("strides: list(int)")
1857 .Attr("use_cudnn_on_gpu: bool = true")
1858 .Attr("is_filter_const: bool = false")
1859 .Attr(GetPaddingAttrStringWithExplicit())
1860 .Attr(GetExplicitPaddingsAttrString())
1861 .Attr(GetConvnetDataFormatAttrString())
1862 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1863 .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding)
1864 .Doc(R"doc(
1865MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform
18662D convolution and add Bias to the output of convolution.
1867
1868NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1869expected to invoke these operators.
1870)doc");
1871
1872REGISTER_OP("__MklDummyPadWithConv2D")
1873 .Input("input: T")
1874 .Input("filter: T")
1875 .Input("paddings: Tpaddings")
1876 .Output("output: T")
1877 .Attr("T: {bfloat16, float}")
1878 .Attr("strides: list(int)")
1879 .Attr("use_cudnn_on_gpu: bool = true")
1880 .Attr(GetPaddingAttrString())
1881 .Attr(GetConvnetDataFormatAttrString())
1882 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1883 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1884 .SetShapeFn(shape_inference::Conv2DShape)
1885 .Doc(R"doc(
1886Dummy node that enables fusing Pad and Conv2D operator for MKL. This node
1887does not perform anything. It is just created as an intermediate output of
1888merging Pad and Conv2D.
1889
1890NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1891expected to invoke these operators.
1892)doc");
1893
1894REGISTER_OP("_MklPadWithConv2D")
1895 .Input("input: T")
1896 .Input("filter: T")
1897 .Input("paddings: Tpaddings")
1898 .Input("mkl_input: uint8")
1899 .Input("mkl_filter: uint8")
1900 .Input("mkl_paddings: uint8")
1901 .Output("output: T")
1902 .Output("filter_output: T")
1903 .Output("mkl_output: uint8")
1904 .Output("mkl_filter_output: uint8")
1905 .Attr("T: {bfloat16, float}")
1906 .Attr("strides: list(int)")
1907 .Attr("use_cudnn_on_gpu: bool = true")
1908 .Attr(GetPaddingAttrString())
1909 .Attr(GetConvnetDataFormatAttrString())
1910 .Attr("is_filter_const: bool = false")
1911 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1912 .Attr("Tpaddings: {int32, int64} = DT_INT32")
1913 .SetShapeFn(shape_inference::Conv2DShape)
1914 .Doc(R"doc(
1915MKL version of Pad and Conv2D operator. Uses MKL DNN APIs to perform
1916Pad and 2D convolution to the output of convolution.
1917
1918NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1919expected to invoke these operators.
1920)doc");
1921
1922REGISTER_OP("_MklConv2DBackpropFilter")
1923 .Input("input: T")
1924 .Input("filter_sizes: int32")
1925 .Input("out_backprop: T")
1926 .Input("mkl_input: uint8")
1927 .Input("mkl_filter_size: uint8")
1928 .Input("mkl_out_backprop: uint8")
1929 .Output("output: T")
1930 .Output("mkl_output: uint8")
1931 .Attr("T: {bfloat16, float}")
1932 .Attr("strides: list(int)")
1933 .Attr("use_cudnn_on_gpu: bool = true")
1934 .Attr(GetPaddingAttrString())
1935 .Attr(GetConvnetDataFormatAttrString())
1936 .Attr(GetExplicitPaddingsAttrString())
1937 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1938 .SetShapeFn([](InferenceContext* c) {
1939 ShapeHandle s;
1940 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1941 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1942 c->set_output(0, s);
1943 return Status::OK();
1944 })
1945 .Doc(R"doc(
1946MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
1947gradients of convolution with respect to the filter.
1948
1949NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1950expected to invoke these operators.
1951)doc");
1952
1953REGISTER_OP("_MklNativeConv2DBackpropFilter")
1954 .Input("input: T")
1955 .Input("filter_sizes: int32")
1956 .Input("out_backprop: T")
1957 .Output("output: T")
1958 .Attr("T: {bfloat16, float}")
1959 .Attr("strides: list(int)")
1960 .Attr("use_cudnn_on_gpu: bool = true")
1961 .Attr(GetPaddingAttrStringWithExplicit())
1962 .Attr(GetExplicitPaddingsAttrString())
1963 .Attr(GetConvnetDataFormatAttrString())
1964 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1965 .SetShapeFn([](InferenceContext* c) {
1966 ShapeHandle s;
1967 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1968 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1969 c->set_output(0, s);
1970 return Status::OK();
1971 })
1972 .Doc(R"doc(
1973MKL version of Conv2DBackpropFilter for Eager mode. Uses MKL DNN APIs
1974to compute the gradients of convolution with respect to the filter.
1975
1976NOTE Do not invoke this operator directly in Python. Eager Op rewrite pass is
1977expected to invoke these operators.
1978)doc");
1979
1980REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias")
1981 .Input("input: T")
1982 .Input("filter_sizes: int32")
1983 .Input("out_backprop: T")
1984 .Output("output: T")
1985 .Output("bias_grad: T")
1986 .Attr("T: {bfloat16, float}")
1987 .Attr("strides: list(int)")
1988 .Attr("use_cudnn_on_gpu: bool = true")
1989 .Attr(GetPaddingAttrString())
1990 .Attr(GetConvnetDataFormatAttrString())
1991 .Attr("dilations: list(int) = [1, 1, 1, 1]")
1992 .SetShapeFn([](InferenceContext* c) {
1993 ShapeHandle input_shape;
1994 // Fetch the data_format attribute, which may not exist.
1995 string data_format;
1996 Status s = c->GetAttr("data_format", &data_format);
1997
1998 if (s.ok() && data_format == "NCHW") {
1999 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2000 c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
2001 } else {
2002 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
2003 c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
2004 }
2005 ShapeHandle sh;
2006 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
2007 TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
2008 c->set_output(0, sh);
2009 return Status::OK();
2010 })
2011 .Doc(R"doc(
2012Dummy node that enables fusing Conv2DBackpropFilter and BiasAddGrad operator
2013for MKL. This node does not perform anything. It is just created as an
2014intermediate output of merging Conv2DBackpropFilter and BiasAddGrad.
2015
2016NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2017expected to invoke these operators.
2018)doc");
2019
2020REGISTER_OP("_MklConv2DBackpropFilterWithBias")
2021 .Input("input: T")
2022 .Input("filter_sizes: int32")
2023 .Input("out_backprop: T")
2024 .Input("mkl_input: uint8")
2025 .Input("mkl_filter_size: uint8")
2026 .Input("mkl_out_backprop: uint8")
2027 .Output("output: T")
2028 .Output("bias_grad: T")
2029 .Output("mkl_output: uint8")
2030 .Output("mkl_bias_grad: uint8")
2031 .Attr("T: {bfloat16, float}")
2032 .Attr("strides: list(int)")
2033 .Attr("use_cudnn_on_gpu: bool = true")
2034 .Attr(GetPaddingAttrString())
2035 .Attr(GetConvnetDataFormatAttrString())
2036 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2037 .SetShapeFn(shape_inference::Conv2DBackpropFilterWithBiasShape)
2038 .Doc(R"doc(
2039MKL version of Conv2DBackpropFilterWithBias. Uses MKL DNN APIs to compute the
2040gradients of convolution with respect to the filter.
2041
2042NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2043expected to invoke these operators.
2044)doc");
2045
2046#ifdef INTEL_MKL_ML_ONLY
2047REGISTER_OP("_MklConv2DWithBiasBackpropBias")
2048 .Input("out_backprop: T")
2049 .Input("mkl_out_backprop: uint8")
2050 .Output("output: T")
2051 .Output("mkl_output: uint8")
2052 .Attr("T: {half, float, double}")
2053 .Attr("strides: list(int)")
2054 .Attr(GetConvnetDataFormatAttrString())
2055 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2056 .Doc(R"doc(
2057MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the
2058gradients of convolution with respect to the bias.
2059
2060NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2061expected to invoke these operators.
2062)doc");
2063#endif
2064
2065REGISTER_OP("_MklConv2DBackpropInput")
2066 .Input("input_sizes: int32")
2067 .Input("filter: T")
2068 .Input("out_backprop: T")
2069 .Input("mkl_input_sizes: uint8")
2070 .Input("mkl_filter: uint8")
2071 .Input("mkl_out_backprop: uint8")
2072 .Output("output: T")
2073 .Output("mkl_output: uint8")
2074 .Attr("T: {bfloat16, float}")
2075 .Attr("strides: list(int)")
2076 .Attr("use_cudnn_on_gpu: bool = true")
2077 .Attr(GetPaddingAttrString())
2078 .Attr(GetConvnetDataFormatAttrString())
2079 .Attr(GetExplicitPaddingsAttrString())
2080 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2081 .SetShapeFn([](InferenceContext* c) {
2082 ShapeHandle s;
2083 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2084 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
2085 c->set_output(0, s);
2086 return Status::OK();
2087 })
2088 .Doc(R"doc(
2089MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the
2090gradients of convolution with respect to the input.
2091
2092NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2093expected to invoke these operators.
2094)doc");
2095
2096REGISTER_OP("_MklNativeConv2DBackpropInput")
2097 .Input("input_sizes: int32")
2098 .Input("filter: T")
2099 .Input("out_backprop: T")
2100 .Output("output: T")
2101 .Attr("T: {bfloat16, float}")
2102 .Attr("strides: list(int)")
2103 .Attr("use_cudnn_on_gpu: bool = true")
2104 .Attr(GetPaddingAttrStringWithExplicit())
2105 .Attr(GetExplicitPaddingsAttrString())
2106 .Attr(GetConvnetDataFormatAttrString())
2107 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2108 .SetShapeFn([](InferenceContext* c) {
2109 ShapeHandle s;
2110 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2111 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
2112 c->set_output(0, s);
2113 return Status::OK();
2114 })
2115 .Doc(R"doc(
2116MKL version of Convolution2D backward input for Eager mode. Uses MKL DNN APIs
2117to compute the gradients of convolution with respect to the input.
2118
2119NOTE Do not invoke this operator directly in Python. Eager op rewrite is
2120expected to invoke these operators.
2121)doc");
2122
2123REGISTER_OP("_MklConv3D")
2124 .Input("input: T")
2125 .Input("filter: T")
2126 .Input("mkl_input: uint8")
2127 .Input("mkl_filter: uint8")
2128 .Output("output: T")
2129 .Output("filter_output: T")
2130 .Output("mkl_output: uint8")
2131 .Output("mkl_filter_output: uint8")
2132 .Attr("T: {bfloat16, float}")
2133 .Attr("strides: list(int) >= 5")
2134 .Attr("is_filter_const: bool = false")
2135 .Attr(GetPaddingAttrString())
2136 .Attr(GetConvnet3dDataFormatAttrString())
2137 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
2138 .SetShapeFn(shape_inference::Conv3DShape)
2139 .Doc(R"doc(
2140MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution.
2141
2142NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2143expected to invoke these operators.
2144)doc");
2145
2146REGISTER_OP("_MklConv3DBackpropInputV2")
2147 .Input("input_sizes: Tshape")
2148 .Input("filter: T")
2149 .Input("out_backprop: T")
2150 .Input("mkl_input_sizes: uint8")
2151 .Input("mkl_filter: uint8")
2152 .Input("mkl_out_backprop: uint8")
2153 .Output("output: T")
2154 .Output("mkl_output: uint8")
2155 .Attr("T: {bfloat16, float}")
2156 .Attr("strides: list(int) >= 5")
2157 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
2158 .Attr("Tshape: {int32, int64} = DT_INT32")
2159 .Attr(GetPaddingAttrString())
2160 .Attr(GetConvnet3dDataFormatAttrString())
2161 .SetShapeFn([](InferenceContext* c) {
2162 ShapeHandle s;
2163 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2164 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
2165 c->set_output(0, s);
2166 return Status::OK();
2167 })
2168 .Doc(R"doc(
2169MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the
2170gradients of convolution with respect to the input.
2171
2172NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2173expected to invoke these operators.
2174)doc");
2175
2176REGISTER_OP("_MklConv3DBackpropFilterV2")
2177 .Input("input: T")
2178 .Input("filter_sizes: int32")
2179 .Input("out_backprop: T")
2180 .Input("mkl_input: uint8")
2181 .Input("mkl_filter_size: uint8")
2182 .Input("mkl_out_backprop: uint8")
2183 .Output("output: T")
2184 .Output("mkl_output: uint8")
2185 .Attr("T: {bfloat16, float}")
2186 .Attr("strides: list(int)")
2187 .Attr(GetPaddingAttrString())
2188 .Attr(GetConvnet3dDataFormatAttrString())
2189 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
2190 .SetShapeFn([](InferenceContext* c) {
2191 ShapeHandle s;
2192 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
2193 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
2194 c->set_output(0, s);
2195 return Status::OK();
2196 })
2197 .Doc(R"doc(
2198MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the
2199gradients of convolution with respect to the filter.
2200
2201NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2202expected to invoke these operators.
2203)doc");
2204
2205REGISTER_OP("_MklRelu")
2206 .Input("features: T")
2207 .Input("mkl_features: uint8")
2208 .Output("activations: T")
2209 .Output("mkl_activations: uint8")
2210 .Attr("T: {float, bfloat16} = DT_FLOAT")
2211 .SetShapeFn(shape_inference::UnchangedShape)
2212 .Doc(R"doc(
2213MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator.
2214
2215NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2216expected to invoke these operators.
2217)doc");
2218
2219REGISTER_OP("_MklReluGrad")
2220 .Input("gradients: T")
2221 .Input("features: T")
2222 .Input("mkl_gradients: uint8")
2223 .Input("mkl_features: uint8")
2224 .Output("backprops: T")
2225 .Output("mkl_backprops: uint8")
2226 .Attr("T: {float, bfloat16} = DT_FLOAT")
2227 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2228 .Doc(R"doc(
2229MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
2230linear gradients for Relu operation.
2231
2232NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2233expected to invoke these operators.
2234)doc");
2235
2236REGISTER_OP("_MklRelu6")
2237 .Input("features: T")
2238 .Input("mkl_features: uint8")
2239 .Output("activations: T")
2240 .Output("mkl_activations: uint8")
2241 .Attr("T: {float, bfloat16} = DT_FLOAT")
2242 .SetShapeFn(shape_inference::UnchangedShape)
2243 .Doc(R"doc(
2244MKL version of Relu6 operator. Uses MKL DNN APIs to implement Relu6 operator.
2245
2246NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2247expected to invoke these operators.
2248)doc");
2249
2250REGISTER_OP("_MklRelu6Grad")
2251 .Input("gradients: T")
2252 .Input("features: T")
2253 .Input("mkl_gradients: uint8")
2254 .Input("mkl_features: uint8")
2255 .Output("backprops: T")
2256 .Output("mkl_backprops: uint8")
2257 .Attr("T: {float, bfloat16} = DT_FLOAT")
2258 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2259 .Doc(R"doc(
2260MKL version of Relu6Grad operator. Uses MKL DNN APIs to compute rectified
2261linear gradients for Relu6 operation.
2262
2263NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2264expected to invoke these operators.
2265)doc");
2266
2267REGISTER_OP("_MklLeakyRelu")
2268 .Input("features: T")
2269 .Input("mkl_features: uint8")
2270 .Output("activations: T")
2271 .Output("mkl_activations: uint8")
2272 .Attr("T: {float, bfloat16} = DT_FLOAT")
2273 .Attr("alpha: float = 0.2")
2274 .SetShapeFn(shape_inference::UnchangedShape)
2275 .Doc(R"doc(
2276MKL version of LeakyRelu operator. Uses MKL DNN APIs to implement
2277LeakyRelu operator.
2278
2279NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2280expected to invoke these operators.
2281)doc");
2282
2283REGISTER_OP("_MklLeakyReluGrad")
2284 .Input("gradients: T")
2285 .Input("features: T")
2286 .Input("mkl_gradients: uint8")
2287 .Input("mkl_features: uint8")
2288 .Output("backprops: T")
2289 .Output("mkl_backprops: uint8")
2290 .Attr("T: {float, bfloat16} = DT_FLOAT")
2291 .Attr("alpha: float = 0.2")
2292 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2293 .Doc(R"doc(
2294MKL version of LeakyReluGrad operator. Uses MKL DNN APIs to compute rectified
2295linear gradients for LeakyReluGrad operation.
2296
2297NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2298expected to invoke these operators.
2299)doc");
2300
2301REGISTER_OP("_MklElu")
2302 .Input("features: T")
2303 .Input("mkl_features: uint8")
2304 .Output("activations: T")
2305 .Output("mkl_activations: uint8")
2306 .Attr("T: {float, bfloat16} = DT_FLOAT")
2307 .SetShapeFn(shape_inference::UnchangedShape)
2308 .Doc(R"doc(
2309MKL version of Elu operator. Uses MKL DNN APIs to implement Elu operator.
2310NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2311expected to invoke these operators.
2312)doc");
2313
2314REGISTER_OP("_MklEluGrad")
2315 .Input("gradients: T")
2316 .Input("features: T")
2317 .Input("mkl_gradients: uint8")
2318 .Input("mkl_features: uint8")
2319 .Output("backprops: T")
2320 .Output("mkl_backprops: uint8")
2321 .Attr("T: {float, bfloat16} = DT_FLOAT")
2322 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2323 .Doc(R"doc(
2324MKL version of EluGrad operator. Uses MKL DNN APIs to compute Elu
2325gradients for Elu operation.
2326NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2327expected to invoke these operators.
2328)doc");
2329
2330REGISTER_OP("_MklSoftmax")
2331 .Input("logits: T")
2332 .Input("mkl_logits: uint8")
2333 .Output("softmax: T")
2334 .Output("mkl_softmax: uint8")
2335 .Attr("T: {bfloat16, half, float, double}")
2336 .SetShapeFn([](InferenceContext* c) {
2337 return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
2338 })
2339 .Doc(R"doc(
2340MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
2341linear gradients for Relu operation.
2342)doc");
2343
2344REGISTER_OP("_MklTanh")
2345 .Input("features: T")
2346 .Input("mkl_features: uint8")
2347 .Output("activations: T")
2348 .Output("mkl_activations: uint8")
2349 .Attr("T: realnumbertype")
2350 .SetShapeFn(shape_inference::UnchangedShape)
2351 .Doc(R"doc(
2352MKL version of Tanh operator. Uses MKL DNN APIs to implement Tanh operator.
2353NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2354expected to invoke these operators.
2355)doc");
2356
2357REGISTER_OP("_MklTanhGrad")
2358 .Input("gradients: T")
2359 .Input("features: T")
2360 .Input("mkl_gradients: uint8")
2361 .Input("mkl_features: uint8")
2362 .Output("backprops: T")
2363 .Output("mkl_backprops: uint8")
2364 .Attr("T: realnumbertype")
2365 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2366 .Doc(R"doc(
2367MKL version of TanhGrad operator. Uses MKL DNN APIs to compute tanh
2368gradients for Tanh operation.
2369NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2370expected to invoke these operators.
2371)doc");
2372
2373REGISTER_OP("_MklMaxPool")
2374 .Attr("T: {float, half, bfloat16} = DT_FLOAT")
2375 .Attr("ksize: list(int) >= 4")
2376 .Attr("strides: list(int) >= 4")
2377 .Attr(GetPaddingAttrString())
2378 .Attr(GetConvnetDataFormatAttrString())
2379 .Attr(GetExplicitPaddingsAttrString())
2380 .Attr("workspace_enabled: bool = false")
2381 .Input("input: T")
2382 .Input("mkl_input: uint8")
2383 .Output("output: T")
2384#ifdef INTEL_MKL_ML_ONLY
2385 .Output("workspace: T")
2386#else
2387 .Output("workspace: uint8")
2388#endif
2389 .Output("mkl_output: uint8")
2390 .Output("mkl_workspace: uint8")
2391 .SetShapeFn(shape_inference::MaxPoolShape)
2392 .Doc(R"doc(
2393MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling
2394on the input.
2395
2396NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2397expected to invoke these operators.
2398)doc");
2399
2400REGISTER_OP("_MklMaxPoolGrad")
2401 .Attr("T: {float, half, bfloat16} = DT_FLOAT")
2402 .Attr("ksize: list(int) >= 4")
2403 .Attr("strides: list(int) >= 4")
2404 .Attr("workspace_enabled: bool = false")
2405 .Attr(GetPaddingAttrString())
2406 .Attr(GetConvnetDataFormatAttrString())
2407 .Attr(GetExplicitPaddingsAttrString())
2408 .Input("orig_input: T")
2409 .Input("orig_output: T")
2410 .Input("grad: T")
2411#ifdef INTEL_MKL_ML_ONLY
2412 .Input("workspace: T")
2413#else
2414 .Input("workspace: uint8")
2415#endif
2416 .Input("mkl_orig_input: uint8")
2417 .Input("mkl_orig_output: uint8")
2418 .Input("mkl_grad: uint8")
2419 .Input("mkl_workspace: uint8")
2420 .Output("output: T")
2421 .Output("mkl_output: uint8")
2422 .SetShapeFn(shape_inference::MaxPoolGradShape)
2423 .Doc(R"doc(
2424oneDNN version of MaxPoolGrad. Uses oneDNN APIs to compute gradients of
2425MaxPool operator.
2426
2427*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2428expected to invoke these operators.
2429)doc");
2430
2431REGISTER_OP("_MklAvgPool")
2432 .Input("value: T")
2433 .Input("mkl_input: uint8")
2434 .Output("output: T")
2435 .Output("mkl_output: uint8")
2436 .Attr("ksize: list(int) >= 4")
2437 .Attr("strides: list(int) >= 4")
2438 .Attr(GetPaddingAttrString())
2439 .Attr(GetConvnetDataFormatAttrString())
2440 .Attr("T: {float, half, double, bfloat16}")
2441 .SetShapeFn(shape_inference::AvgPoolShape)
2442 .Doc(R"doc(
2443MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling
2444on the input.
2445
2446NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2447expected to invoke these operators.
2448)doc");
2449
2450REGISTER_OP("_MklAvgPoolGrad")
2451 .Input("orig_input_shape: int32")
2452 .Input("grad: T")
2453 .Input("mkl_orig_input: uint8")
2454 .Input("mkl_grad: uint8")
2455 .Output("output: T")
2456 .Output("mkl_output: uint8")
2457 .Attr("ksize: list(int) >= 4")
2458 .Attr("strides: list(int) >= 4")
2459 .Attr(GetPaddingAttrString())
2460 .Attr(GetConvnetDataFormatAttrString())
2461 .Attr("T: {float, half, double, bfloat16}")
2462 .SetShapeFn(shape_inference::AvgPoolGradShape)
2463 .Doc(R"doc(
2464oneDNN version of AvgPoolGrad operator. Uses oneDNN APIs to compute gradients
2465of AvgPool function.
2466
2467*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2468expected to invoke these operators.
2469)doc");
2470
2471REGISTER_OP("_MklAvgPool3D")
2472 .Input("value: T")
2473 .Input("mkl_input: uint8")
2474 .Output("output: T")
2475 .Output("mkl_output: uint8")
2476 .Attr("ksize: list(int) >= 5")
2477 .Attr("strides: list(int) >= 5")
2478 .Attr(GetPaddingAttrString())
2479 .Attr(GetConvnet3dDataFormatAttrString())
2480 .Attr("T: {float, half, double, bfloat16}")
2481 .SetShapeFn(shape_inference::Pool3DShape)
2482 .Doc(R"doc(
2483MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling
2484on the input.
2485
2486NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2487expected to invoke these operators.
2488)doc");
2489
2490REGISTER_OP("_MklAvgPool3DGrad")
2491 .Input("orig_input_shape: int32")
2492 .Input("grad: T")
2493 .Input("mkl_orig_input: uint8")
2494 .Input("mkl_grad: uint8")
2495 .Output("output: T")
2496 .Output("mkl_output: uint8")
2497 .Attr("ksize: list(int) >= 5")
2498 .Attr("strides: list(int) >= 5")
2499 .Attr(GetPaddingAttrString())
2500 .Attr(GetConvnet3dDataFormatAttrString())
2501 .Attr("T: {float, half, double, bfloat16}")
2502 .SetShapeFn(shape_inference::AvgPool3DGradShape)
2503 .Doc(R"doc(
2504oneDNN version of AvgPool3DGrad operator. Uses oneDNN APIs to compute gradients
2505of AvgPool function.
2506
2507*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2508expected to invoke these operators.
2509)doc");
2510
2511REGISTER_OP("_MklMaxPool3D")
2512 .Input("input: T")
2513 .Input("mkl_input: uint8")
2514 .Output("output: T")
2515 .Output("workspace: uint8")
2516 .Output("mkl_output: uint8")
2517 .Output("mkl_workspace: uint8")
2518 .Attr("ksize: list(int) >= 5")
2519 .Attr("strides: list(int) >= 5")
2520 .Attr(GetPaddingAttrString())
2521 .Attr(GetConvnet3dDataFormatAttrString())
2522 .Attr("T: {half, bfloat16, float}")
2523 .Attr("workspace_enabled: bool = false")
2524 .SetShapeFn(shape_inference::Pool3DShape)
2525 .Doc(R"doc(
2526MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform average pooling
2527on the input.
2528
2529NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2530expected to invoke these operators.
2531)doc");
2532
2533REGISTER_OP("_MklMaxPool3DGrad")
2534 .Input("orig_input: TInput")
2535 .Input("orig_output: TInput")
2536 .Input("grad: T")
2537 .Input("workspace: uint8")
2538 .Input("mkl_orig_input: uint8")
2539 .Input("mkl_orig_output: uint8")
2540 .Input("mkl_grad: uint8")
2541 .Input("mkl_workspace: uint8")
2542 .Output("output: T")
2543 .Output("mkl_output: uint8")
2544 .Attr("ksize: list(int) >= 5")
2545 .Attr("strides: list(int) >= 5")
2546 .Attr(GetPaddingAttrString())
2547 .Attr(GetConvnet3dDataFormatAttrString())
2548 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
2549 .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
2550 .Attr("workspace_enabled: bool = false")
2551 .SetShapeFn(shape_inference::MaxPool3DGradShape)
2552 .Doc(R"doc(
2553oneDNN version of MaxPool3DGrad operator. Uses oneDNN APIs to compute gradients
2554of MaxPool3D function.
2555
2556*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2557expected to invoke these operators.
2558)doc");
2559
2560REGISTER_OP("_MklLRN")
2561 .Input("input: T")
2562 .Input("mkl_input: uint8")
2563 .Output("output: T")
2564 .Output("workspace: uint8")
2565 .Output("mkl_output: uint8")
2566 .Output("mkl_workspace: uint8")
2567 .Attr("depth_radius: int = 5")
2568 .Attr("bias: float = 1.0")
2569 .Attr("alpha: float = 1.0")
2570 .Attr("beta: float = 0.5")
2571 .Attr("workspace_enabled: bool = false")
2572 .Attr("T: {float, half} = DT_FLOAT")
2573 .SetShapeFn([](InferenceContext* c) {
2574 return UnchangedShapeWithRank(c, 4);
2575 })
2576 .Doc(R"doc(
2577MKL version of LRN operator. Uses MKL DNN APIs to perform local response
2578normalization.
2579
2580NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2581expected to invoke these operators.
2582)doc");
2583
2584REGISTER_OP("_MklLRNGrad")
2585 .Input("input_grads: T")
2586 .Input("input_image: T")
2587 .Input("output_image: T")
2588 .Input("workspace: uint8")
2589 .Input("mkl_input_grads: uint8")
2590 .Input("mkl_input_image: uint8")
2591 .Input("mkl_output_image: uint8")
2592 .Input("mkl_workspace: uint8")
2593 .Output("output: T")
2594 .Output("mkl_output: uint8")
2595 .Attr("depth_radius: int = 5")
2596 .Attr("bias: float = 1.0")
2597 .Attr("alpha: float = 1.0")
2598 .Attr("beta: float = 0.5")
2599 .Attr("workspace_enabled: bool = false")
2600 .Attr("T: {float, half} = DT_FLOAT")
2601 .SetShapeFn([](InferenceContext* c) {
2602 ShapeHandle s;
2603 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
2604 TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
2605 TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
2606 c->set_output(0, s);
2607 return Status::OK();
2608 })
2609 .Doc(R"doc(
2610MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for
2611local response normalization.
2612
2613NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2614expected to invoke these operators.
2615)doc");
2616
2617REGISTER_OP("_MklFusedBatchNorm")
2618 .Input("x: T")
2619 .Input("scale: T")
2620 .Input("offset: T")
2621 .Input("mean: T")
2622 .Input("variance: T")
2623 .Input("mkl_x: uint8")
2624 .Input("mkl_scale: uint8")
2625 .Input("mkl_offset: uint8")
2626 .Input("mkl_mean: uint8")
2627 .Input("mkl_variance: uint8")
2628 .Output("y: T")
2629 .Output("batch_mean: T")
2630 .Output("batch_variance: T")
2631 .Output("reserve_space_1: T")
2632 .Output("reserve_space_2: T")
2633 .Output("mkl_y: uint8")
2634 .Output("mkl_batch_mean: uint8")
2635 .Output("mkl_batch_variance: uint8")
2636 .Output("mkl_reserve_space_1: uint8")
2637 .Output("mkl_reserve_space_2: uint8")
2638 .Attr("T: numbertype")
2639 .Attr("epsilon: float = 0.0001")
2640 .Attr("data_format: string = 'NHWC'")
2641 .Attr("exponential_avg_factor: float = 1.0")
2642 .Attr("is_training: bool = true")
2643 .SetShapeFn(shape_inference::FusedBatchNormShape)
2644 .Doc(R"doc(
2645oneDNN version of FusedBatchNorm operator. Uses oneDNN APIs to perform fused
2646batch normalization.
2647
2648*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2649expected to invoke these operators.
2650)doc");
2651
2652REGISTER_OP("_MklFusedBatchNormGrad")
2653 .Input("y_backprop: T")
2654 .Input("x: T")
2655 .Input("scale: T")
2656 .Input("reserve_space_1: T")
2657 .Input("reserve_space_2: T")
2658 .Input("mkl_y_backprop: uint8")
2659 .Input("mkl_x: uint8")
2660 .Input("mkl_scale: uint8")
2661 .Input("mkl_reserve_space_1: uint8")
2662 .Input("mkl_reserve_space_2: uint8")
2663 .Output("x_backprop: T")
2664 .Output("scale_backprop: T")
2665 .Output("offset_backprop: T")
2666 .Output("reserve_space_3: T")
2667 .Output("reserve_space_4: T")
2668 .Output("mkl_x_backprop: uint8")
2669 .Output("mkl_scale_backprop: uint8")
2670 .Output("mkl_offset_backprop: uint8")
2671 .Output("mkl_reserve_space_3: uint8")
2672 .Output("mkl_reserve_space_4: uint8")
2673 .Attr("T: numbertype")
2674 .Attr("epsilon: float = 0.0001")
2675 .Attr("data_format: string = 'NHWC'")
2676 .Attr("is_training: bool = true")
2677 .SetShapeFn(shape_inference::FusedBatchNormGradShape)
2678 .Doc(R"doc(
2679oneDNN version of FusedBatchNormGrad operator. Uses oneDNN APIs to compute
2680gradients for fused batch normalization.
2681
2682*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
2683expected to invoke these operators.
2684)doc");
2685
2686REGISTER_OP("_MklFusedBatchNormV2")
2687 .Input("x: T")
2688 .Input("scale: U")
2689 .Input("offset: U")
2690 .Input("mean: U")
2691 .Input("variance: U")
2692 .Input("mkl_x: uint8")
2693 .Input("mkl_scale: uint8")
2694 .Input("mkl_offset: uint8")
2695 .Input("mkl_mean: uint8")
2696 .Input("mkl_variance: uint8")
2697 .Output("y: T")
2698 .Output("batch_mean: U")
2699 .Output("batch_variance: U")
2700 .Output("reserve_space_1: U")
2701 .Output("reserve_space_2: U")
2702 .Output("mkl_y: uint8")
2703 .Output("mkl_batch_mean: uint8")
2704 .Output("mkl_batch_variance: uint8")
2705 .Output("mkl_reserve_space_1: uint8")
2706 .Output("mkl_reserve_space_2: uint8")
2707 .Attr("T: {bfloat16, float}")
2708 .Attr("U: {float}")
2709 .Attr("epsilon: float = 0.0001")
2710 .Attr(GetConvnetDataFormatAttrString())
2711 .Attr("exponential_avg_factor: float = 1.0")
2712 .Attr("is_training: bool = true")
2713 .SetShapeFn(shape_inference::FusedBatchNormShape);
2714
2715REGISTER_OP("_MklFusedBatchNormGradV2")
2716 .Input("y_backprop: T")
2717 .Input("x: T")
2718 .Input("scale: float")
2719 .Input("reserve_space_1: U")
2720 .Input("reserve_space_2: U")
2721 .Input("mkl_y_backprop: uint8")
2722 .Input("mkl_x: uint8")
2723 .Input("mkl_scale: uint8")
2724 .Input("mkl_reserve_space_1: uint8")
2725 .Input("mkl_reserve_space_2: uint8")
2726 .Output("x_backprop: T")
2727 .Output("scale_backprop: U")
2728 .Output("offset_backprop: U")
2729 .Output("reserve_space_3: U")
2730 .Output("reserve_space_4: U")
2731 .Output("mkl_x_backprop: uint8")
2732 .Output("mkl_scale_backprop: uint8")
2733 .Output("mkl_offset_backprop: uint8")
2734 .Output("mkl_reserve_space_3: uint8")
2735 .Output("mkl_reserve_space_4: uint8")
2736 .Attr("T: {bfloat16, float}")
2737 .Attr("U: {float}")
2738 .Attr("epsilon: float = 0.0001")
2739 .Attr(GetConvnetDataFormatAttrString())
2740 .Attr("is_training: bool = true")
2741 .SetShapeFn(shape_inference::FusedBatchNormGradShape);
2742
2743REGISTER_OP("_MklToTf")
2744 .Input("input: T")
2745 .Input("mkl_input: uint8")
2746 .Output("output: T")
2747 .Attr("T: {half, float, double, bfloat16, qint8, quint8, qint32}")
2748 .Attr(GetConvnetDataFormat2D3DAttrString())
2749 .SetShapeFn(shape_inference::UnknownShape)
2750 .Doc(R"doc(
2751MKL operator to convert a tensor from MKL layout to TensorFlow layout.
2752
2753NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2754expected to invoke these operators.
2755)doc");
2756
2757REGISTER_OP("_MklInputConversion")
2758 .Input("input_0: T")
2759 .Input("input_1: T")
2760 .Input("mkl_input_0: uint8")
2761 .Input("mkl_input_1: uint8")
2762 .Output("output_0: T")
2763 .Output("output_1: T")
2764 .Output("mkl_output_0: uint8")
2765 .Output("mkl_output_1: uint8")
2766 // All datatypes supported by element-wise ops
2767 .Attr(
2768 "T: {half, float, bfloat16, double, uint8, int8, uint16, int16, int32, "
2769 "int64, complex64, complex128}")
2770 .Attr(GetConvnetDataFormat2D3DAttrString())
2771 .SetShapeFn(shape_inference::UnknownShape)
2772 .Doc(R"doc(
2773MKL operator to process the inputs to an elementwise MKL op. Both inputs
2774need to be either in TF or in MKL format. This op is added before every
2775element-wise MKL op.
2776
2777NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2778expected to invoke these operators.
2779)doc");
2780
2781#endif // INTEL_MKL
2782REGISTER_OP("QuantizedConv2DAndRequantize")
2783 .Input("input: Tinput")
2784 .Input("filter: Tfilter")
2785 .Input("min_input: float")
2786 .Input("max_input: float")
2787 .Input("min_filter: float")
2788 .Input("max_filter: float")
2789 .Input("min_freezed_output: float")
2790 .Input("max_freezed_output: float")
2791 .Output("output: out_type")
2792 .Output("min_output: float")
2793 .Output("max_output: float")
2794 .Attr("Tinput: quantizedtype")
2795 .Attr("Tfilter: quantizedtype")
2796 .Attr("out_type: quantizedtype = DT_QINT8")
2797 .Attr("strides: list(int)")
2798 .Attr(GetPaddingAttrString())
2799 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2800 .Attr("padding_list: list(int) = []")
2801 .SetShapeFn([](InferenceContext* c) {
2802 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2803 ShapeHandle unused;
2804 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2805 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2806 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2807 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2808 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2809 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2810 c->set_output(1, c->Scalar());
2811 c->set_output(2, c->Scalar());
2812 return OkStatus();
2813 });
2814
2815// Fusion of Quantized Conv2D and BiasAdd.
2816REGISTER_OP("QuantizedConv2DWithBias")
2817 .Input("input: Tinput")
2818 .Input("filter: Tfilter")
2819 .Input("bias: float")
2820 .Input("min_input: float")
2821 .Input("max_input: float")
2822 .Input("min_filter: float")
2823 .Input("max_filter: float")
2824 .Output("output: out_type")
2825 .Output("min_output: float")
2826 .Output("max_output: float")
2827 .Attr("Tinput: quantizedtype")
2828 .Attr("Tfilter: quantizedtype")
2829 .Attr("out_type: quantizedtype = DT_QINT32")
2830 .Attr("strides: list(int)")
2831 .Attr(GetPaddingAttrString())
2832 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2833 .Attr("padding_list: list(int) = []")
2834 .SetShapeFn([](InferenceContext* c) {
2835 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2836 ShapeHandle unused, channel;
2837 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2838 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2839 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2840 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2841 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
2842 c->set_output(1, channel);
2843 c->set_output(2, channel);
2844 return OkStatus();
2845 });
2846
2847REGISTER_OP("QuantizedConv2DWithBiasAndRequantize")
2848 .Input("input: Tinput")
2849 .Input("filter: Tfilter")
2850 .Input("bias: Tbias")
2851 .Input("min_input: float")
2852 .Input("max_input: float")
2853 .Input("min_filter: float")
2854 .Input("max_filter: float")
2855 .Input("min_freezed_output: float")
2856 .Input("max_freezed_output: float")
2857 .Output("output: out_type")
2858 .Output("min_output: float")
2859 .Output("max_output: float")
2860 .Attr("Tinput: quantizedtype")
2861 .Attr("Tfilter: quantizedtype")
2862 .Attr("Tbias: {float, qint32}")
2863 .Attr("out_type: quantizedtype = DT_QINT8")
2864 .Attr("strides: list(int)")
2865 .Attr(GetPaddingAttrString())
2866 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2867 .Attr("padding_list: list(int) = []")
2868 .SetShapeFn([](InferenceContext* c) {
2869 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2870 ShapeHandle unused, channel;
2871 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2872 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2873 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2874 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2875 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
2876 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2877 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2878 c->set_output(1, c->Scalar());
2879 c->set_output(2, c->Scalar());
2880 return OkStatus();
2881 });
2882
2883// Fusion of Quantized Conv2D and Relu.
2884REGISTER_OP("QuantizedConv2DAndRelu")
2885 .Input("input: Tinput")
2886 .Input("filter: Tfilter")
2887 .Input("min_input: float")
2888 .Input("max_input: float")
2889 .Input("min_filter: float")
2890 .Input("max_filter: float")
2891 .Output("output: out_type")
2892 .Output("min_output: float")
2893 .Output("max_output: float")
2894 .Attr("Tinput: quantizedtype")
2895 .Attr("Tfilter: quantizedtype")
2896 .Attr("out_type: quantizedtype = DT_QINT32")
2897 .Attr("strides: list(int)")
2898 .Attr(GetPaddingAttrString())
2899 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2900 .Attr("padding_list: list(int) = []")
2901 .SetShapeFn([](InferenceContext* c) {
2902 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2903 ShapeHandle unused, channel;
2904 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2905 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2906 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
2907 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2908 c->set_output(1, channel);
2909 c->set_output(2, channel);
2910 return OkStatus();
2911 });
2912
2913REGISTER_OP("QuantizedConv2DAndReluAndRequantize")
2914 .Input("input: Tinput")
2915 .Input("filter: Tfilter")
2916 .Input("min_input: float")
2917 .Input("max_input: float")
2918 .Input("min_filter: float")
2919 .Input("max_filter: float")
2920 .Input("min_freezed_output: float")
2921 .Input("max_freezed_output: float")
2922 .Output("output: out_type")
2923 .Output("min_output: float")
2924 .Output("max_output: float")
2925 .Attr("Tinput: quantizedtype")
2926 .Attr("Tfilter: quantizedtype")
2927 .Attr("out_type: quantizedtype = DT_QUINT8")
2928 .Attr("strides: list(int)")
2929 .Attr(GetPaddingAttrString())
2930 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2931 .Attr("padding_list: list(int) = []")
2932 .SetShapeFn([](InferenceContext* c) {
2933 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2934 ShapeHandle unused, channel;
2935 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2936 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2937 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
2938 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2939 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2940 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2941 c->set_output(1, c->Scalar());
2942 c->set_output(2, c->Scalar());
2943 return OkStatus();
2944 });
2945
2946// Fusion of Quantized Conv2D, BiasAdd and Relu.
2947REGISTER_OP("QuantizedConv2DWithBiasAndRelu")
2948 .Input("input: Tinput")
2949 .Input("filter: Tfilter")
2950 .Input("bias: float")
2951 .Input("min_input: float")
2952 .Input("max_input: float")
2953 .Input("min_filter: float")
2954 .Input("max_filter: float")
2955 .Output("output: out_type")
2956 .Output("min_output: float")
2957 .Output("max_output: float")
2958 .Attr("Tinput: quantizedtype")
2959 .Attr("Tfilter: quantizedtype")
2960 .Attr("out_type: quantizedtype = DT_QINT32")
2961 .Attr("strides: list(int)")
2962 .Attr(GetPaddingAttrString())
2963 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2964 .Attr("padding_list: list(int) = []")
2965 .SetShapeFn([](InferenceContext* c) {
2966 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2967 ShapeHandle unused, channel;
2968 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2969 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2970 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2971 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
2972 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
2973 c->set_output(1, channel);
2974 c->set_output(2, channel);
2975 return OkStatus();
2976 });
2977
2978// Fusion of Quantized Conv2D, BiasAdd, Relu, and Requantize.
2979REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize")
2980 .Input("input: Tinput")
2981 .Input("filter: Tfilter")
2982 .Input("bias: Tbias")
2983 .Input("min_input: float")
2984 .Input("max_input: float")
2985 .Input("min_filter: float")
2986 .Input("max_filter: float")
2987 .Input("min_freezed_output: float")
2988 .Input("max_freezed_output: float")
2989 .Output("output: out_type")
2990 .Output("min_output: float")
2991 .Output("max_output: float")
2992 .Attr("Tinput: quantizedtype")
2993 .Attr("Tfilter: quantizedtype")
2994 .Attr("Tbias: {float, qint32}")
2995 .Attr("out_type: quantizedtype = DT_QUINT8")
2996 .Attr("strides: list(int)")
2997 .Attr(GetPaddingAttrString())
2998 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2999 .Attr("padding_list: list(int) = []")
3000 .SetShapeFn([](InferenceContext* c) {
3001 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3002 ShapeHandle unused, channel;
3003 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3004 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3005 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3006 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3007 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3008 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3009 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3010 c->set_output(1, c->Scalar());
3011 c->set_output(2, c->Scalar());
3012 return OkStatus();
3013 });
3014
3015// Fusion of Quantized Conv2D, BiasAdd, Sum, and Relu.
3016REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu")
3017 .Input("input: Tinput")
3018 .Input("filter: Tfilter")
3019 .Input("bias: float")
3020 .Input("min_input: float")
3021 .Input("max_input: float")
3022 .Input("min_filter: float")
3023 .Input("max_filter: float")
3024 .Input("summand: float")
3025 .Output("output: out_type")
3026 .Output("min_output: float")
3027 .Output("max_output: float")
3028 .Attr("Tinput: quantizedtype")
3029 .Attr("Tfilter: quantizedtype")
3030 .Attr("out_type: quantizedtype = DT_QINT32")
3031 .Attr("strides: list(int)")
3032 .Attr(GetPaddingAttrString())
3033 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3034 .Attr("padding_list: list(int) = []")
3035 .SetShapeFn([](InferenceContext* c) {
3036 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3037 ShapeHandle unused, channel;
3038 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3039 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3040 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3041 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3042 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3043 c->set_output(1, channel);
3044 c->set_output(2, channel);
3045 return OkStatus();
3046 });
3047
3048REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize")
3049 .Input("input: Tinput")
3050 .Input("filter: Tfilter")
3051 .Input("bias: Tbias")
3052 .Input("min_input: float")
3053 .Input("max_input: float")
3054 .Input("min_filter: float")
3055 .Input("max_filter: float")
3056 .Input("min_freezed_output: float")
3057 .Input("max_freezed_output: float")
3058 .Input("summand: Tsummand")
3059 .Input("min_summand: float")
3060 .Input("max_summand: float")
3061 .Output("output: out_type")
3062 .Output("min_output: float")
3063 .Output("max_output: float")
3064 .Attr("Tinput: quantizedtype")
3065 .Attr("Tfilter: quantizedtype")
3066 .Attr("Tbias: {float, qint32}")
3067 .Attr("Tsummand: quantizedtype")
3068 .Attr("out_type: quantizedtype = DT_QUINT8")
3069 .Attr("strides: list(int)")
3070 .Attr(GetPaddingAttrString())
3071 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3072 .Attr("padding_list: list(int) = []")
3073 .SetShapeFn([](InferenceContext* c) {
3074 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3075 ShapeHandle unused, channel;
3076 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3077 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3078 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3079 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3080 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3081 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3082 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3083 c->set_output(1, c->Scalar());
3084 c->set_output(2, c->Scalar());
3085 return OkStatus();
3086 });
3087
3088REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
3089 .Input("input: Tinput")
3090 .Input("filter: Tfilter")
3091 .Input("bias: Tbias")
3092 .Input("min_input: float")
3093 .Input("max_input: float")
3094 .Input("min_filter: float")
3095 .Input("max_filter: float")
3096 .Input("min_freezed_output: float")
3097 .Input("max_freezed_output: float")
3098 .Input("summand: Tsummand")
3099 .Input("min_summand: float")
3100 .Input("max_summand: float")
3101 .Output("output: out_type")
3102 .Output("min_output: float")
3103 .Output("max_output: float")
3104 .Attr("Tinput: quantizedtype")
3105 .Attr("Tfilter: quantizedtype")
3106 .Attr("Tbias: {float, qint32}")
3107 .Attr("Tsummand: quantizedtype")
3108 .Attr("out_type: quantizedtype = DT_QUINT8")
3109 .Attr("strides: list(int)")
3110 .Attr(GetPaddingAttrString())
3111 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3112 .Attr("padding_list: list(int) = []")
3113 .SetShapeFn([](InferenceContext* c) {
3114 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3115 ShapeHandle unused, channel;
3116 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3117 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3118 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3119 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3120 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel));
3121 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3122 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3123 // Since activations are not requantized per channel, `min_output`
3124 // and `max_output` are scalars.
3125 c->set_output(1, c->Scalar());
3126 c->set_output(2, c->Scalar());
3127 return OkStatus();
3128 });
3129
3130// Fusion of Quantized MatMul and BiasAdd.
3131REGISTER_OP("QuantizedMatMulWithBias")
3132 .Input("a: T1")
3133 .Input("b: T2")
3134 .Input("bias: Tbias")
3135 .Input("min_a: float")
3136 .Input("max_a: float")
3137 .Input("min_b: float")
3138 .Input("max_b: float")
3139 .Output("out: Toutput")
3140 .Output("min_out: float")
3141 .Output("max_out: float")
3142 .Attr("T1: quantizedtype")
3143 .Attr("T2: quantizedtype")
3144 .Attr("Tbias: {float, qint32}")
3145 .Attr("Toutput: quantizedtype = DT_QINT32")
3146 .Attr("transpose_a: bool = false")
3147 .Attr("transpose_b: bool = false")
3148 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3149 .SetShapeFn([](InferenceContext* c) {
3150 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3151 ShapeHandle unused;
3152 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3153 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3154 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3155 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3156 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3157 c->set_output(1, c->Scalar());
3158 c->set_output(2, c->Scalar());
3159 return OkStatus();
3160 });
3161
3162REGISTER_OP("QuantizedMatMulWithBiasAndRelu")
3163 .Input("a: T1")
3164 .Input("b: T2")
3165 .Input("bias: float")
3166 .Input("min_a: float")
3167 .Input("max_a: float")
3168 .Input("min_b: float")
3169 .Input("max_b: float")
3170 .Output("out: Toutput")
3171 .Output("min_out: float")
3172 .Output("max_out: float")
3173 .Attr("T1: quantizedtype")
3174 .Attr("T2: quantizedtype")
3175 .Attr("Toutput: quantizedtype = DT_QINT32")
3176 .Attr("transpose_a: bool = false")
3177 .Attr("transpose_b: bool = false")
3178 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3179 .SetShapeFn([](InferenceContext* c) {
3180 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3181 ShapeHandle unused;
3182 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3183 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3184 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3185 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3186 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3187 c->set_output(1, c->Scalar());
3188 c->set_output(2, c->Scalar());
3189 return OkStatus();
3190 });
3191
3192REGISTER_OP("QuantizedMatMulWithBiasAndReluAndRequantize")
3193 .Input("a: T1")
3194 .Input("b: T2")
3195 .Input("bias: Tbias")
3196 .Input("min_a: float")
3197 .Input("max_a: float")
3198 .Input("min_b: float")
3199 .Input("max_b: float")
3200 .Input("min_freezed_output: float")
3201 .Input("max_freezed_output: float")
3202 .Output("out: Toutput")
3203 .Output("min_out: float")
3204 .Output("max_out: float")
3205 .Attr("T1: quantizedtype")
3206 .Attr("T2: quantizedtype")
3207 .Attr("Tbias: {float, qint32}")
3208 .Attr("Toutput: quantizedtype = DT_QUINT8")
3209 .Attr("transpose_a: bool = false")
3210 .Attr("transpose_b: bool = false")
3211 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3212 .SetShapeFn([](InferenceContext* c) {
3213 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3214 ShapeHandle unused;
3215 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3216 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3217 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3218 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3219 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3220 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3221 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3222 c->set_output(1, c->Scalar());
3223 c->set_output(2, c->Scalar());
3224 return OkStatus();
3225 });
3226
3227REGISTER_OP("QuantizedMatMulWithBiasAndDequantize")
3228 .Input("a: T1")
3229 .Input("b: T2")
3230 .Input("bias: Tbias")
3231 .Input("min_a: float")
3232 .Input("max_a: float")
3233 .Input("min_b: float")
3234 .Input("max_b: float")
3235 .Input("min_freezed_output: float")
3236 .Input("max_freezed_output: float")
3237 .Output("out: Toutput")
3238 .Attr("T1: quantizedtype")
3239 .Attr("T2: quantizedtype")
3240 .Attr("Tbias: {float, qint32}")
3241 .Attr("Toutput: {float}")
3242 .Attr("transpose_a: bool = false")
3243 .Attr("transpose_b: bool = false")
3244 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3245 .SetShapeFn([](InferenceContext* c) {
3246 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3247 ShapeHandle unused;
3248 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3249 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3250 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3251 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3252 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3253 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3254 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3255
3256 return OkStatus();
3257 });
3258
3259REGISTER_OP("QuantizedMatMulWithBiasAndRequantize")
3260 .Input("a: T1")
3261 .Input("b: T2")
3262 .Input("bias: Tbias")
3263 .Input("min_a: float")
3264 .Input("max_a: float")
3265 .Input("min_b: float")
3266 .Input("max_b: float")
3267 .Input("min_freezed_output: float")
3268 .Input("max_freezed_output: float")
3269 .Output("out: Toutput")
3270 .Output("min_out: float")
3271 .Output("max_out: float")
3272 .Attr("T1: quantizedtype")
3273 .Attr("T2: quantizedtype")
3274 .Attr("Tbias: {float, qint32}")
3275 .Attr("Toutput: quantizedtype = DT_QUINT8")
3276 .Attr("transpose_a: bool = false")
3277 .Attr("transpose_b: bool = false")
3278 .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'")
3279 .SetShapeFn([](InferenceContext* c) {
3280 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
3281 ShapeHandle unused;
3282 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
3283 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3284 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
3285 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
3286 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
3287 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
3288 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
3289 c->set_output(1, c->Scalar());
3290 c->set_output(2, c->Scalar());
3291 return OkStatus();
3292 });
3293
3294REGISTER_OP("QuantizedConv2DPerChannel")
3295 .Input("input: Tinput")
3296 .Input("filter: Tfilter")
3297 .Input("min_input: float")
3298 .Input("max_input: float")
3299 .Input("min_filter: float")
3300 .Input("max_filter: float")
3301 .Output("output: out_type")
3302 .Output("min_output: float")
3303 .Output("max_output: float")
3304 .Attr("Tinput: quantizedtype")
3305 .Attr("Tfilter: quantizedtype")
3306 .Attr("out_type: quantizedtype = DT_QINT32")
3307 .Attr("strides: list(int)")
3308 .Attr(GetPaddingAttrString())
3309 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3310 .SetShapeFn([](InferenceContext* c) {
3311 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
3312 ShapeHandle unused, channel;
3313 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
3314 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
3315 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel));
3316 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel));
3317 c->set_output(1, channel);
3318 c->set_output(2, channel);
3319 return OkStatus();
3320 });
3321
3322REGISTER_OP("QuantizedDepthwiseConv2D")
3323 .Input("input: Tinput")
3324 .Input("filter: Tfilter")
3325 .Input("min_input: float")
3326 .Input("max_input: float")
3327 .Input("min_filter: float")
3328 .Input("max_filter: float")
3329 .Output("output: out_type")
3330 .Output("min_output: float")
3331 .Output("max_output: float")
3332 .Attr("Tinput: quantizedtype")
3333 .Attr("Tfilter: quantizedtype")
3334 .Attr("out_type: quantizedtype = DT_QINT32")
3335 .Attr("strides: list(int)")
3336 .Attr(GetPaddingAttrString())
3337 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3338 .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3339
3340REGISTER_OP("QuantizedDepthwiseConv2DWithBias")
3341 .Input("input: Tinput")
3342 .Input("filter: Tfilter")
3343 .Input("bias: float")
3344 .Input("min_input: float")
3345 .Input("max_input: float")
3346 .Input("min_filter: float")
3347 .Input("max_filter: float")
3348 .Output("output: out_type")
3349 .Output("min_output: float")
3350 .Output("max_output: float")
3351 .Attr("Tinput: quantizedtype")
3352 .Attr("Tfilter: quantizedtype")
3353 .Attr("out_type: quantizedtype = DT_QINT32")
3354 .Attr("strides: list(int)")
3355 .Attr(GetPaddingAttrString())
3356 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3357 .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3358
3359REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndRelu")
3360 .Input("input: Tinput")
3361 .Input("filter: Tfilter")
3362 .Input("bias: float")
3363 .Input("min_input: float")
3364 .Input("max_input: float")
3365 .Input("min_filter: float")
3366 .Input("max_filter: float")
3367 .Output("output: out_type")
3368 .Output("min_output: float")
3369 .Output("max_output: float")
3370 .Attr("Tinput: quantizedtype")
3371 .Attr("Tfilter: quantizedtype")
3372 .Attr("out_type: quantizedtype = DT_QINT32")
3373 .Attr("strides: list(int)")
3374 .Attr(GetPaddingAttrString())
3375 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3376 .Attr("padding_list: list(int) = []")
3377 .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3378
3379REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize")
3380 .Input("input: Tinput")
3381 .Input("filter: Tfilter")
3382 .Input("bias: Tbias")
3383 .Input("min_input: float")
3384 .Input("max_input: float")
3385 .Input("min_filter: float")
3386 .Input("max_filter: float")
3387 .Input("min_freezed_output: float")
3388 .Input("max_freezed_output: float")
3389 .Output("output: out_type")
3390 .Output("min_output: float")
3391 .Output("max_output: float")
3392 .Attr("Tinput: quantizedtype")
3393 .Attr("Tfilter: quantizedtype")
3394 .Attr("Tbias: {float, qint32}")
3395 .Attr("out_type: quantizedtype = DT_QUINT8")
3396 .Attr("strides: list(int)")
3397 .Attr(GetPaddingAttrString())
3398 .Attr("dilations: list(int) = [1, 1, 1, 1]")
3399 .Attr("padding_list: list(int) = []")
3400 .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
3401
3402REGISTER_OP("IsotonicRegression")
3403 .Input("input: T")
3404 .Output("output: output_dtype")
3405 .Output("segments: int32")
3406 .Attr("T: realnumbertype")
3407 .Attr("output_dtype: {half, bfloat16, float, double} = DT_FLOAT")
3408 .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* context) {
3409 context->set_output(0, context->input(0));
3410 context->set_output(1, context->input(0));
3411 return OkStatus();
3412 });
3413
3414} // namespace tensorflow
3415