1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <cmath> |
18 | |
19 | #include "tensorflow/core/framework/common_shape_fns.h" |
20 | #include "tensorflow/core/framework/kernel_shape_util.h" |
21 | #include "tensorflow/core/framework/numeric_op.h" |
22 | #include "tensorflow/core/framework/op.h" |
23 | #include "tensorflow/core/framework/shape_inference.h" |
24 | #include "tensorflow/core/lib/core/bits.h" |
25 | #include "tensorflow/core/lib/math/math_util.h" |
26 | #include "tensorflow/core/util/mirror_pad_mode.h" |
27 | #include "tensorflow/core/util/padding.h" |
28 | #include "tensorflow/core/util/tensor_format.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | using shape_inference::DimensionHandle; |
33 | using shape_inference::InferenceContext; |
34 | using shape_inference::ShapeHandle; |
35 | |
36 | namespace { |
37 | |
38 | Status FractionalPoolShapeFn(InferenceContext* c) { |
39 | ShapeHandle input; |
40 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
41 | |
42 | std::vector<float> pooling_ratio; |
43 | TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio" , &pooling_ratio)); |
44 | if (pooling_ratio.size() != 4) { |
45 | return errors::InvalidArgument( |
46 | "pooling_ratio field must specify 4 dimensions" ); |
47 | } |
48 | std::vector<DimensionHandle> output_dims; |
49 | for (int i = 0; i < 4; ++i) { |
50 | DimensionHandle d = c->Dim(input, i); |
51 | if (c->ValueKnown(d)) { |
52 | // This must match the same logic in the kernel function in |
53 | // core/kernels/fractional_max_pool_op.cc. |
54 | auto val = |
55 | static_cast<int64_t>(std::floor(c->Value(d) / pooling_ratio[i])); |
56 | if (val < 0) { |
57 | return errors::InvalidArgument("Size computed for dim " , i, |
58 | " is negative: " , val); |
59 | } |
60 | output_dims.push_back(c->MakeDim(val)); |
61 | } else { |
62 | output_dims.push_back(c->UnknownDim()); |
63 | } |
64 | } |
65 | |
66 | c->set_output(0, c->MakeShape(output_dims)); |
67 | c->set_output(1, c->Vector(output_dims[1])); |
68 | c->set_output(2, c->Vector(output_dims[2])); |
69 | return OkStatus(); |
70 | } |
71 | |
72 | } // namespace |
73 | |
74 | // -------------------------------------------------------------------------- |
75 | |
76 | REGISTER_OP("AvgPool" ) |
77 | .Input("value: T" ) |
78 | .Output("output: T" ) |
79 | .Attr("ksize: list(int) >= 4" ) |
80 | .Attr("strides: list(int) >= 4" ) |
81 | .Attr(GetPaddingAttrString()) |
82 | .Attr(GetConvnetDataFormatAttrString()) |
83 | .Attr("T: {half, bfloat16, float, double}" ) |
84 | .SetShapeFn(shape_inference::AvgPoolShape); |
85 | |
86 | REGISTER_OP("AvgPoolGrad" ) |
87 | .Input("orig_input_shape: int32" ) |
88 | .Input("grad: T" ) |
89 | .Output("output: T" ) |
90 | .Attr("ksize: list(int) >= 4" ) |
91 | .Attr("strides: list(int) >= 4" ) |
92 | .Attr(GetPaddingAttrString()) |
93 | .Attr(GetConvnetDataFormatAttrString()) |
94 | .Attr("T: {half, bfloat16, float, double}" ) |
95 | .SetShapeFn(shape_inference::AvgPoolGradShape); |
96 | |
97 | // -------------------------------------------------------------------------- |
98 | |
99 | REGISTER_OP("BatchNormWithGlobalNormalization" ) |
100 | .Input("t: T" ) |
101 | .Input("m: T" ) |
102 | .Input("v: T" ) |
103 | .Input("beta: T" ) |
104 | .Input("gamma: T" ) |
105 | .Output("result: T" ) |
106 | .Attr("T: numbertype" ) |
107 | .Attr("variance_epsilon: float" ) |
108 | .Attr("scale_after_normalization: bool" ) |
109 | .Deprecated(9, "Use tf.nn.batch_normalization()" ) |
110 | .SetShapeFn([](InferenceContext* c) { |
111 | ShapeHandle input; |
112 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
113 | |
114 | DimensionHandle last_dim = c->Dim(input, 3); |
115 | for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma |
116 | ShapeHandle vec; |
117 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); |
118 | TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim)); |
119 | } |
120 | |
121 | ShapeHandle out; |
122 | TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out)); |
123 | c->set_output(0, out); |
124 | return OkStatus(); |
125 | }); |
126 | |
127 | REGISTER_OP("BatchNormWithGlobalNormalizationGrad" ) |
128 | .Input("t: T" ) |
129 | .Input("m: T" ) |
130 | .Input("v: T" ) |
131 | .Input("gamma: T" ) |
132 | .Input("backprop: T" ) |
133 | .Output("dx: T" ) |
134 | .Output("dm: T" ) |
135 | .Output("dv: T" ) |
136 | .Output("db: T" ) |
137 | .Output("dg: T" ) |
138 | .Attr("T: numbertype" ) |
139 | .Attr("variance_epsilon: float" ) |
140 | .Attr("scale_after_normalization: bool" ) |
141 | .Deprecated(9, "Use tf.nn.batch_normalization()" ) |
142 | .SetShapeFn([](InferenceContext* c) { |
143 | ShapeHandle input; |
144 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
145 | TF_RETURN_IF_ERROR( |
146 | c->Merge(input, c->input(4), &input)); // with backprop |
147 | |
148 | DimensionHandle last_dim = c->Dim(input, 3); |
149 | for (int i = 1; i < 4; ++i) { // covers m, v, gamma |
150 | ShapeHandle vec; |
151 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec)); |
152 | TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim)); |
153 | } |
154 | |
155 | ShapeHandle dx; |
156 | TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx)); |
157 | c->set_output(0, dx); |
158 | |
159 | ShapeHandle vector_shape = c->Vector(last_dim); |
160 | c->set_output(1, vector_shape); |
161 | c->set_output(2, vector_shape); |
162 | c->set_output(3, vector_shape); |
163 | c->set_output(4, vector_shape); |
164 | return OkStatus(); |
165 | }); |
166 | |
167 | // -------------------------------------------------------------------------- |
168 | |
169 | REGISTER_OP("FusedBatchNorm" ) |
170 | .Input("x: T" ) |
171 | .Input("scale: T" ) |
172 | .Input("offset: T" ) |
173 | .Input("mean: T" ) |
174 | .Input("variance: T" ) |
175 | .Output("y: T" ) |
176 | .Output("batch_mean: T" ) |
177 | .Output("batch_variance: T" ) |
178 | .Output("reserve_space_1: T" ) |
179 | .Output("reserve_space_2: T" ) |
180 | .Attr("T: {float}" ) |
181 | .Attr("epsilon: float = 0.0001" ) |
182 | .Attr("exponential_avg_factor: float = 1.0" ) |
183 | .Attr(GetConvnetDataFormatAttrString()) |
184 | .Attr("is_training: bool = true" ) |
185 | .SetShapeFn(shape_inference::FusedBatchNormShape); |
186 | |
187 | REGISTER_OP("FusedBatchNormV2" ) |
188 | .Input("x: T" ) |
189 | .Input("scale: U" ) |
190 | .Input("offset: U" ) |
191 | .Input("mean: U" ) |
192 | .Input("variance: U" ) |
193 | .Output("y: T" ) |
194 | .Output("batch_mean: U" ) |
195 | .Output("batch_variance: U" ) |
196 | .Output("reserve_space_1: U" ) |
197 | .Output("reserve_space_2: U" ) |
198 | .Attr("T: {half, bfloat16, float}" ) |
199 | .Attr("U: {float}" ) |
200 | .Attr("epsilon: float = 0.0001" ) |
201 | .Attr("exponential_avg_factor: float = 1.0" ) |
202 | .Attr(GetConvnetDataFormatAttrString()) |
203 | .Attr("is_training: bool = true" ) |
204 | .SetShapeFn(shape_inference::FusedBatchNormShape); |
205 | |
206 | REGISTER_OP("FusedBatchNormV3" ) |
207 | .Input("x: T" ) |
208 | .Input("scale: U" ) |
209 | .Input("offset: U" ) |
210 | .Input("mean: U" ) |
211 | .Input("variance: U" ) |
212 | .Output("y: T" ) |
213 | .Output("batch_mean: U" ) |
214 | .Output("batch_variance: U" ) |
215 | .Output("reserve_space_1: U" ) |
216 | .Output("reserve_space_2: U" ) |
217 | .Output("reserve_space_3: U" ) |
218 | .Attr("T: {half, bfloat16, float}" ) |
219 | .Attr("U: {bfloat16, float}" ) |
220 | .Attr("epsilon: float = 0.0001" ) |
221 | .Attr("exponential_avg_factor: float = 1.0" ) |
222 | .Attr(GetConvnetDataFormat2D3DAttrString()) |
223 | .Attr("is_training: bool = true" ) |
224 | .SetShapeFn(shape_inference::FusedBatchNormV3Shape); |
225 | |
226 | REGISTER_OP("_FusedBatchNormEx" ) |
227 | .Input("x: T" ) |
228 | .Input("scale: U" ) |
229 | .Input("offset: U" ) |
230 | .Input("mean: U" ) |
231 | .Input("variance: U" ) |
232 | .Input("side_input: num_side_inputs * T" ) |
233 | .Output("y: T" ) |
234 | .Output("batch_mean: U" ) |
235 | .Output("batch_variance: U" ) |
236 | .Output("reserve_space_1: U" ) |
237 | .Output("reserve_space_2: U" ) |
238 | .Output("reserve_space_3: U" ) |
239 | .Attr("T: {half, float, bfloat16}" ) |
240 | .Attr("U: {float}" ) |
241 | .Attr("epsilon: float = 0.0001" ) |
242 | .Attr("exponential_avg_factor: float = 1.0" ) |
243 | .Attr("num_side_inputs: int >= 0 = 0" ) |
244 | .Attr("activation_mode: string = \"Identity\"" ) |
245 | .Attr(GetConvnetDataFormatAttrString()) |
246 | .Attr("is_training: bool = true" ) |
247 | .SetShapeFn(shape_inference::FusedBatchNormExShape) |
248 | .Doc(R"doc( |
249 | Internal FusedBatchNorm operation: reserved for internal use. |
250 | |
251 | Do not invoke this operator directly in Python. A fusion optimization is |
252 | expected to create these operators. |
253 | )doc" ); |
254 | |
255 | REGISTER_OP("FusedBatchNormGrad" ) |
256 | .Input("y_backprop: T" ) |
257 | .Input("x: T" ) |
258 | .Input("scale: T" ) |
259 | .Input("reserve_space_1: T" ) |
260 | .Input("reserve_space_2: T" ) |
261 | .Output("x_backprop: T" ) |
262 | .Output("scale_backprop: T" ) |
263 | .Output("offset_backprop: T" ) |
264 | .Output("reserve_space_3: T" ) |
265 | .Output("reserve_space_4: T" ) |
266 | .Attr("T: {float}" ) |
267 | .Attr("epsilon: float = 0.0001" ) |
268 | .Attr(GetConvnetDataFormatAttrString()) |
269 | .Attr("is_training: bool = true" ) |
270 | .SetShapeFn(shape_inference::FusedBatchNormGradShape); |
271 | |
272 | REGISTER_OP("FusedBatchNormGradV2" ) |
273 | .Input("y_backprop: T" ) |
274 | .Input("x: T" ) |
275 | .Input("scale: float" ) |
276 | .Input("reserve_space_1: U" ) |
277 | .Input("reserve_space_2: U" ) |
278 | .Output("x_backprop: T" ) |
279 | .Output("scale_backprop: U" ) |
280 | .Output("offset_backprop: U" ) |
281 | .Output("reserve_space_3: U" ) |
282 | .Output("reserve_space_4: U" ) |
283 | .Attr("T: {half, bfloat16, float}" ) |
284 | .Attr("U: {float}" ) |
285 | .Attr("epsilon: float = 0.0001" ) |
286 | .Attr(GetConvnetDataFormatAttrString()) |
287 | .Attr("is_training: bool = true" ) |
288 | .SetShapeFn(shape_inference::FusedBatchNormGradShape); |
289 | |
290 | REGISTER_OP("FusedBatchNormGradV3" ) |
291 | .Input("y_backprop: T" ) |
292 | .Input("x: T" ) |
293 | .Input("scale: float" ) |
294 | .Input("reserve_space_1: U" ) |
295 | .Input("reserve_space_2: U" ) |
296 | .Input("reserve_space_3: U" ) |
297 | .Output("x_backprop: T" ) |
298 | .Output("scale_backprop: U" ) |
299 | .Output("offset_backprop: U" ) |
300 | .Output("reserve_space_4: U" ) |
301 | .Output("reserve_space_5: U" ) |
302 | .Attr("T: {half, bfloat16, float}" ) |
303 | .Attr("U: {float}" ) |
304 | .Attr("epsilon: float = 0.0001" ) |
305 | .Attr(GetConvnetDataFormat2D3DAttrString()) |
306 | .Attr("is_training: bool = true" ) |
307 | .SetShapeFn(shape_inference::FusedBatchNormGradShape); |
308 | |
309 | REGISTER_OP("_FusedBatchNormGradEx" ) |
310 | .Input("y_backprop: T" ) |
311 | .Input("x: T" ) |
312 | .Input("scale: float" ) |
313 | .Input("reserve_space_1: U" ) |
314 | .Input("reserve_space_2: U" ) |
315 | .Input("reserve_space_3: U" ) |
316 | .Input("offset: float" ) |
317 | .Input("y: T" ) |
318 | .Output("x_backprop: T" ) |
319 | .Output("scale_backprop: U" ) |
320 | .Output("offset_backprop: U" ) |
321 | .Output("reserve_space_4: U" ) |
322 | .Output("reserve_space_5: U" ) |
323 | .Output("side_input_backprop: num_side_inputs * T" ) |
324 | .Attr("T: {half, float}" ) |
325 | .Attr("U: {float}" ) |
326 | .Attr("epsilon: float = 0.0001" ) |
327 | .Attr("num_side_inputs: int >= 0 = 0" ) |
328 | .Attr("activation_mode: string = \"Identity\"" ) |
329 | .Attr(GetConvnetDataFormat2D3DAttrString()) |
330 | .Attr("is_training: bool = true" ) |
331 | .SetShapeFn(shape_inference::FusedBatchNormGradExShape) |
332 | .Doc(R"doc( |
333 | Internal FusedBatchNormGrad operation: reserved for internal use. |
334 | |
335 | Do not invoke this operator directly in Python. A fusion optimization is |
336 | expected to create these operators. |
337 | )doc" ); |
338 | // -------------------------------------------------------------------------- |
339 | |
340 | REGISTER_OP("BiasAdd" ) |
341 | .Attr("T: numbertype" ) |
342 | .Input("value: T" ) |
343 | .Input("bias: T" ) |
344 | .Attr(GetConvnetDataFormatAttrString()) |
345 | .Output("output: T" ) |
346 | .SetShapeFn(shape_inference::BiasAddShape); |
347 | // -------------------------------------------------------------------------- |
348 | |
349 | REGISTER_OP("BiasAddGrad" ) |
350 | .Attr("T: numbertype" ) |
351 | .Input("out_backprop: T" ) |
352 | .Attr(GetConvnetDataFormatAttrString()) |
353 | .Output("output: T" ) |
354 | .SetShapeFn(shape_inference::BiasAddGradShape); |
355 | // -------------------------------------------------------------------------- |
356 | |
357 | REGISTER_OP("BiasAddV1" ) |
358 | .Attr("T: numbertype" ) |
359 | .Input("value: T" ) |
360 | .Input("bias: T" ) |
361 | .Output("output: T" ) |
362 | .SetShapeFn(shape_inference::BiasAddShape); |
363 | // -------------------------------------------------------------------------- |
364 | |
365 | REGISTER_OP("Conv2D" ) |
366 | .Input("input: T" ) |
367 | .Input("filter: T" ) |
368 | .Output("output: T" ) |
369 | .Attr("T: {half, bfloat16, float, double, int32}" ) |
370 | .Attr("strides: list(int)" ) |
371 | .Attr("use_cudnn_on_gpu: bool = true" ) |
372 | .Attr(GetPaddingAttrStringWithExplicit()) |
373 | .Attr(GetExplicitPaddingsAttrString()) |
374 | .Attr(GetConvnetDataFormatAttrString()) |
375 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
376 | .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding); |
377 | |
378 | REGISTER_OP("Conv2DBackpropInput" ) |
379 | .Input("input_sizes: int32" ) |
380 | .Input("filter: T" ) |
381 | .Input("out_backprop: T" ) |
382 | .Output("output: T" ) |
383 | .Attr("T: {half, bfloat16, float, double, int32}" ) |
384 | .Attr("strides: list(int)" ) |
385 | .Attr("use_cudnn_on_gpu: bool = true" ) |
386 | .Attr(GetPaddingAttrStringWithExplicit()) |
387 | .Attr(GetExplicitPaddingsAttrString()) |
388 | .Attr(GetConvnetDataFormatAttrString()) |
389 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
390 | .SetShapeFn(shape_inference::Conv2DBackpropInputShape); |
391 | |
392 | // TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a |
393 | // more general string attribute ('kernel_impl'?) that can be used to |
394 | // select among several possible implementations. |
395 | REGISTER_OP("Conv2DBackpropFilter" ) |
396 | .Input("input: T" ) |
397 | .Input("filter_sizes: int32" ) |
398 | .Input("out_backprop: T" ) |
399 | .Output("output: T" ) |
400 | .Attr("T: {half, bfloat16, float, double}" ) |
401 | .Attr("strides: list(int)" ) |
402 | .Attr("use_cudnn_on_gpu: bool = true" ) |
403 | .Attr(GetPaddingAttrStringWithExplicit()) |
404 | .Attr(GetExplicitPaddingsAttrString()) |
405 | .Attr(GetConvnetDataFormatAttrString()) |
406 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
407 | .SetShapeFn([](InferenceContext* c) { |
408 | ShapeHandle s; |
409 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); |
410 | TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); |
411 | c->set_output(0, s); |
412 | return OkStatus(); |
413 | }); |
414 | |
415 | REGISTER_OP("_FusedConv2D" ) |
416 | .Input("input: T" ) |
417 | .Input("filter: T" ) |
418 | .Input("args: TArgs" ) |
419 | .Input("host_args : num_host_args * float" ) |
420 | .Output("output: T" ) |
421 | .Attr("T: {half, float, double, int8, qint8}" ) |
422 | .Attr("TArgs: list(type)" ) |
423 | .Attr("num_args: int >= 0" ) |
424 | .Attr("num_host_args: int >= 0 =0" ) |
425 | .Attr("strides: list(int)" ) |
426 | .Attr(GetPaddingAttrStringWithExplicit()) |
427 | .Attr(GetExplicitPaddingsAttrString()) |
428 | .Attr("data_format: { 'NHWC', 'NCHW', 'NCHW_VECT_C' } = 'NHWC'" ) |
429 | .Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'" ) |
430 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
431 | .Attr("use_cudnn_on_gpu: bool = true" ) |
432 | .Attr("fused_ops: list(string) = []" ) |
433 | // Attributes for the FusedBatchNorm ------------------------------------ // |
434 | .Attr("epsilon: float = 0.0001" ) |
435 | // Attributes for the LeakyRelu ----------------------------------------- // |
436 | .Attr("leakyrelu_alpha: float = 0.2" ) |
437 | // ---------------------------------------------------------------------- // |
438 | .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding) |
439 | .Doc(R"doc( |
440 | Performs a convolution followed by a specified series of operations. |
441 | |
442 | The inputs to the convolution are `input` and `filter`. The series of operations |
443 | that follows is specified by the `fused_ops` attribute, which is a list of TF op |
444 | names specified as strings (e.g. "Relu"). They are performed in order, where the |
445 | (first) input to each op is the output of the preceding op. The first input and |
446 | the output of each fused_op must be of type T. |
447 | |
448 | Currently supported fused_op combinations are: [X] and [X,A], where X is one of |
449 | {"BiasAdd","FusedBatchNorm"} and A is one of {"Elu","Relu","Relu6"}. |
450 | |
451 | * The first input to op X is the Conv2D result, and the additional input(s) to X |
452 | are specified by `args`. |
453 | * If there is an op A specified, the output of op X is the input to op A, and op |
454 | A produces the _FusedConv2D output. Otherwise, op X produces the _FusedConv2D |
455 | output. |
456 | |
457 | *NOTE*: Do not invoke this operator directly in Python. Grappler is expected to |
458 | create these operators. |
459 | )doc" ); |
460 | |
461 | namespace { |
462 | |
463 | Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) { |
464 | ShapeHandle input; |
465 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
466 | |
467 | ShapeHandle resized = input; |
468 | int paddings_index = 1; |
469 | int filter_index = 2; |
470 | if (has_resize) { |
471 | paddings_index = 2; |
472 | filter_index = 3; |
473 | |
474 | ShapeHandle unused_size; |
475 | TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size)); |
476 | |
477 | const Tensor* size = c->input_tensor(1); |
478 | DimensionHandle new_height = c->UnknownDim(); |
479 | DimensionHandle new_width = c->UnknownDim(); |
480 | if (size != nullptr) { |
481 | new_height = c->MakeDim(size->flat<int32>()(0)); |
482 | new_width = c->MakeDim(size->flat<int32>()(1)); |
483 | } |
484 | TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized)); |
485 | TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized)); |
486 | } |
487 | |
488 | ShapeHandle paddings; |
489 | TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings)); |
490 | TF_RETURN_IF_ERROR( |
491 | c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized)); |
492 | TF_RETURN_IF_ERROR( |
493 | c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings)); |
494 | |
495 | const Tensor* paddings_t = c->input_tensor(paddings_index); |
496 | ShapeHandle padded; |
497 | if (paddings_t != nullptr) { |
498 | std::vector<DimensionHandle> output_dims; |
499 | for (int i = 0; i < 4; ++i) { |
500 | DimensionHandle dim = c->Dim(resized, i); |
501 | int64_t p0 = static_cast<int64_t>(paddings_t->matrix<int32>()(i, 0)); |
502 | int64_t p1 = static_cast<int64_t>(paddings_t->matrix<int32>()(i, 1)); |
503 | if (p0 < 0 || p1 < 0) { |
504 | return errors::InvalidArgument("Paddings must be non-negative" ); |
505 | } |
506 | |
507 | TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim)); |
508 | output_dims.push_back(dim); |
509 | } |
510 | padded = c->MakeShape(output_dims); |
511 | } else { |
512 | padded = c->UnknownShapeOfRank(4); |
513 | } |
514 | |
515 | // Work out the convolution's effect with 'padded' as the input. |
516 | ShapeHandle filter; |
517 | TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter)); |
518 | std::vector<int32> strides; |
519 | TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides)); |
520 | if (strides.size() != 4) { |
521 | return errors::InvalidArgument( |
522 | "Operation requires the stride attribute to contain 4 values, but " , |
523 | "got: " , strides.size()); |
524 | } |
525 | |
526 | int32_t stride_rows = strides[1]; |
527 | int32_t stride_cols = strides[2]; |
528 | |
529 | DimensionHandle batch_size_dim = c->Dim(padded, 0); |
530 | DimensionHandle in_rows_dim = c->Dim(padded, 1); |
531 | DimensionHandle in_cols_dim = c->Dim(padded, 2); |
532 | DimensionHandle filter_rows_dim = c->Dim(filter, 0); |
533 | DimensionHandle filter_cols_dim = c->Dim(filter, 1); |
534 | DimensionHandle output_depth_dim = c->Dim(filter, 3); |
535 | |
536 | DimensionHandle unused; |
537 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused)); |
538 | |
539 | Padding padding; |
540 | TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding)); |
541 | |
542 | DimensionHandle output_rows, output_cols; |
543 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( |
544 | c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows)); |
545 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( |
546 | c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols)); |
547 | |
548 | ShapeHandle output_shape = c->MakeShape( |
549 | {batch_size_dim, output_rows, output_cols, output_depth_dim}); |
550 | c->set_output(0, output_shape); |
551 | return OkStatus(); |
552 | } |
553 | |
554 | } // namespace |
555 | |
556 | REGISTER_OP("DataFormatDimMap" ) |
557 | .Input("x: T" ) |
558 | .Output("y: T" ) |
559 | .Attr("T: {int32, int64} = DT_INT32" ) |
560 | .Attr("src_format: string = 'NHWC'" ) |
561 | .Attr("dst_format: string = 'NCHW'" ) |
562 | .SetShapeFn(shape_inference::UnchangedShape); |
563 | |
564 | REGISTER_OP("DataFormatVecPermute" ) |
565 | .Input("x: T" ) |
566 | .Output("y: T" ) |
567 | .Attr("T: {int32, int64} = DT_INT32" ) |
568 | .Attr("src_format: string = 'NHWC'" ) |
569 | .Attr("dst_format: string = 'NCHW'" ) |
570 | .SetShapeFn(shape_inference::UnchangedShape); |
571 | |
572 | REGISTER_OP("FusedResizeAndPadConv2D" ) |
573 | .Input("input: T" ) |
574 | .Input("size: int32" ) |
575 | .Input("paddings: int32" ) |
576 | .Input("filter: T" ) |
577 | .Output("output: T" ) |
578 | .Attr("T: {half, float, double}" ) |
579 | .Attr("resize_align_corners: bool = false" ) |
580 | .Attr(GetMirrorPadModeAttrString()) |
581 | .Attr("strides: list(int)" ) |
582 | .Attr(GetPaddingAttrString()) |
583 | .SetShapeFn([](InferenceContext* c) { |
584 | return CommonFusedConvCalculations(c, true /* has_resize */); |
585 | }); |
586 | |
587 | REGISTER_OP("FusedPadConv2D" ) |
588 | .Input("input: T" ) |
589 | .Input("paddings: int32" ) |
590 | .Input("filter: T" ) |
591 | .Output("output: T" ) |
592 | .Attr("T: {half, float, double}" ) |
593 | .Attr(GetMirrorPadModeAttrString()) |
594 | .Attr("strides: list(int)" ) |
595 | .Attr(GetPaddingAttrString()) |
596 | .SetShapeFn([](InferenceContext* c) { |
597 | return CommonFusedConvCalculations(c, false /* has_resize */); |
598 | }); |
599 | |
600 | // -------------------------------------------------------------------------- |
601 | |
602 | REGISTER_OP("DepthwiseConv2dNative" ) |
603 | .Input("input: T" ) |
604 | .Input("filter: T" ) |
605 | .Output("output: T" ) |
606 | .Attr("T: {half, bfloat16, float, double}" ) |
607 | .Attr("strides: list(int)" ) |
608 | .Attr(GetPaddingAttrStringWithExplicit()) |
609 | .Attr(GetExplicitPaddingsAttrString()) |
610 | .Attr(GetConvnetDataFormatAttrString()) |
611 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
612 | .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding); |
613 | |
614 | REGISTER_OP("DepthwiseConv2dNativeBackpropInput" ) |
615 | .Input("input_sizes: int32" ) |
616 | .Input("filter: T" ) |
617 | .Input("out_backprop: T" ) |
618 | .Output("output: T" ) |
619 | .Attr("T: {half, bfloat16, float, double}" ) |
620 | .Attr("strides: list(int)" ) |
621 | .Attr(GetPaddingAttrStringWithExplicit()) |
622 | .Attr(GetExplicitPaddingsAttrString()) |
623 | .Attr(GetConvnetDataFormatAttrString()) |
624 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
625 | .SetShapeFn([](InferenceContext* c) { |
626 | ShapeHandle s; |
627 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); |
628 | TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); |
629 | c->set_output(0, s); |
630 | return OkStatus(); |
631 | }); |
632 | |
633 | REGISTER_OP("DepthwiseConv2dNativeBackpropFilter" ) |
634 | .Input("input: T" ) |
635 | .Input("filter_sizes: int32" ) |
636 | .Input("out_backprop: T" ) |
637 | .Output("output: T" ) |
638 | .Attr("T: {half, bfloat16, float, double}" ) |
639 | .Attr("strides: list(int)" ) |
640 | .Attr(GetPaddingAttrStringWithExplicit()) |
641 | .Attr(GetExplicitPaddingsAttrString()) |
642 | .Attr(GetConvnetDataFormatAttrString()) |
643 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
644 | .SetShapeFn([](InferenceContext* c) { |
645 | ShapeHandle s; |
646 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); |
647 | TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); |
648 | c->set_output(0, s); |
649 | return OkStatus(); |
650 | }); |
651 | |
652 | REGISTER_OP("_FusedDepthwiseConv2dNative" ) |
653 | .Input("input: T" ) |
654 | .Input("filter: T" ) |
655 | .Input("args: num_args * T" ) |
656 | .Output("output: T" ) |
657 | .Attr("T: {half, bfloat16, float, double}" ) |
658 | .Attr("num_args: int >= 0" ) |
659 | .Attr("strides: list(int)" ) |
660 | .Attr(GetPaddingAttrString()) |
661 | .Attr(GetConvnetDataFormatAttrString()) |
662 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
663 | .Attr("fused_ops: list(string) = []" ) |
664 | // Attributes for the FusedBatchNorm ------------------------------------ // |
665 | .Attr("epsilon: float = 0.0001" ) |
666 | // Attributes for the LeakyRelu ----------------------------------------- // |
667 | .Attr("leakyrelu_alpha: float = 0.2" ) |
668 | // ---------------------------------------------------------------------- // |
669 | .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); |
670 | |
671 | // -------------------------------------------------------------------------- |
672 | |
673 | REGISTER_OP("Conv3D" ) |
674 | .Input("input: T" ) |
675 | .Input("filter: T" ) |
676 | .Output("output: T" ) |
677 | .Attr("T: {half, bfloat16, float, double}" ) |
678 | .Attr("strides: list(int) >= 5" ) |
679 | .Attr(GetPaddingAttrString()) |
680 | .Attr(GetConvnet3dDataFormatAttrString()) |
681 | .Attr("dilations: list(int) = [1, 1, 1, 1, 1]" ) |
682 | .SetShapeFn(shape_inference::Conv3DShape); |
683 | |
684 | REGISTER_OP("Conv3DBackpropInput" ) |
685 | .Input("input: T" ) |
686 | .Input("filter: T" ) |
687 | .Input("out_backprop: T" ) |
688 | .Output("output: T" ) |
689 | .Attr("T: {half, float, double}" ) |
690 | .Attr("strides: list(int) >= 5" ) |
691 | .Attr(GetPaddingAttrString()) |
692 | .Deprecated(10, "Use Conv3DBackpropInputV2" ) |
693 | .Attr("dilations: list(int) = [1, 1, 1, 1, 1]" ) |
694 | .SetShapeFn([](InferenceContext* c) { |
695 | return UnchangedShapeWithRank(c, 5); |
696 | }); |
697 | |
698 | REGISTER_OP("Conv3DBackpropFilter" ) |
699 | .Input("input: T" ) |
700 | .Input("filter: T" ) |
701 | .Input("out_backprop: T" ) |
702 | .Output("output: T" ) |
703 | .Attr("T: {half, float, double}" ) |
704 | .Attr("strides: list(int) >= 5" ) |
705 | .Attr(GetPaddingAttrString()) |
706 | .Deprecated(10, "Use Conv3DBackpropFilterV2" ) |
707 | .Attr("dilations: list(int) = [1, 1, 1, 1, 1]" ) |
708 | .SetShapeFn([](InferenceContext* c) { |
709 | ShapeHandle out; |
710 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out)); |
711 | c->set_output(0, out); |
712 | return OkStatus(); |
713 | }); |
714 | |
715 | REGISTER_OP("Conv3DBackpropInputV2" ) |
716 | .Input("input_sizes: Tshape" ) |
717 | .Input("filter: T" ) |
718 | .Input("out_backprop: T" ) |
719 | .Output("output: T" ) |
720 | .Attr("T: {half, bfloat16, float, double}" ) |
721 | .Attr("strides: list(int) >= 5" ) |
722 | .Attr(GetPaddingAttrString()) |
723 | .Attr(GetConvnet3dDataFormatAttrString()) |
724 | .Attr("dilations: list(int) = [1, 1, 1, 1, 1]" ) |
725 | .Attr("Tshape: {int32, int64} = DT_INT32" ) |
726 | .SetShapeFn([](InferenceContext* c) { |
727 | ShapeHandle s; |
728 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); |
729 | TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); |
730 | c->set_output(0, s); |
731 | return OkStatus(); |
732 | }); |
733 | |
734 | REGISTER_OP("Conv3DBackpropFilterV2" ) |
735 | .Input("input: T" ) |
736 | .Input("filter_sizes: int32" ) |
737 | .Input("out_backprop: T" ) |
738 | .Output("output: T" ) |
739 | .Attr("T: {half, bfloat16, float, double}" ) |
740 | .Attr("strides: list(int) >= 5" ) |
741 | .Attr(GetPaddingAttrString()) |
742 | .Attr(GetConvnet3dDataFormatAttrString()) |
743 | .Attr("dilations: list(int) = [1, 1, 1, 1, 1]" ) |
744 | .SetShapeFn([](InferenceContext* c) { |
745 | ShapeHandle s; |
746 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); |
747 | TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); |
748 | c->set_output(0, s); |
749 | return OkStatus(); |
750 | }); |
751 | |
752 | // -------------------------------------------------------------------------- |
753 | |
754 | REGISTER_OP("AvgPool3D" ) |
755 | .Input("input: T" ) |
756 | .Output("output: T" ) |
757 | .Attr("ksize: list(int) >= 5" ) |
758 | .Attr("strides: list(int) >= 5" ) |
759 | .Attr(GetPaddingAttrString()) |
760 | .Attr(GetConvnet3dDataFormatAttrString()) |
761 | .Attr("T: {half, bfloat16, float, double}" ) |
762 | .SetShapeFn(shape_inference::Pool3DShape); |
763 | |
764 | REGISTER_OP("AvgPool3DGrad" ) |
765 | .Input("orig_input_shape: int32" ) |
766 | .Input("grad: T" ) |
767 | .Output("output: T" ) |
768 | .Attr("ksize: list(int) >= 5" ) |
769 | .Attr("strides: list(int) >= 5" ) |
770 | .Attr(GetPaddingAttrString()) |
771 | .Attr(GetConvnet3dDataFormatAttrString()) |
772 | .Attr("T: {half, bfloat16, float, double}" ) |
773 | .SetShapeFn(shape_inference::AvgPool3DGradShape); |
774 | |
775 | // -------------------------------------------------------------------------- |
776 | |
777 | REGISTER_OP("MaxPool3D" ) |
778 | .Input("input: T" ) |
779 | .Output("output: T" ) |
780 | .Attr("ksize: list(int) >= 5" ) |
781 | .Attr("strides: list(int) >= 5" ) |
782 | .Attr(GetPaddingAttrString()) |
783 | .Attr(GetConvnet3dDataFormatAttrString()) |
784 | .Attr("T: {half, bfloat16, float}" ) |
785 | .SetShapeFn(shape_inference::Pool3DShape); |
786 | |
787 | REGISTER_OP("MaxPool3DGrad" ) |
788 | .Input("orig_input: TInput" ) |
789 | .Input("orig_output: TInput" ) |
790 | .Input("grad: T" ) |
791 | .Output("output: T" ) |
792 | .Attr("ksize: list(int) >= 5" ) |
793 | .Attr("strides: list(int) >= 5" ) |
794 | .Attr(GetPaddingAttrString()) |
795 | .Attr(GetConvnet3dDataFormatAttrString()) |
796 | .Attr("T: {half, bfloat16, float} = DT_FLOAT" ) |
797 | .Attr("TInput: {half, bfloat16, float} = DT_FLOAT" ) |
798 | .SetShapeFn(shape_inference::MaxPool3DGradShape); |
799 | |
800 | REGISTER_OP("MaxPool3DGradGrad" ) |
801 | .Input("orig_input: T" ) |
802 | .Input("orig_output: T" ) |
803 | .Input("grad: T" ) |
804 | .Output("output: T" ) |
805 | .Attr("ksize: list(int) >= 5 " ) |
806 | .Attr("strides: list(int) >= 5" ) |
807 | .Attr(GetPaddingAttrString()) |
808 | .Attr(GetConvnet3dDataFormatAttrString()) |
809 | .Attr("T: realnumbertype" ) |
810 | .SetShapeFn([](InferenceContext* c) { |
811 | TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c)); |
812 | ShapeHandle unused; |
813 | // Validate 'orig_input' is the same shape as 'grad' |
814 | TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); |
815 | // Validate 'orig_output' is same shape as 'output' |
816 | TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); |
817 | return OkStatus(); |
818 | }); |
819 | |
820 | // -------------------------------------------------------------------------- |
821 | |
822 | REGISTER_OP("L2Loss" ) |
823 | .Input("t: T" ) |
824 | .Output("output: T" ) |
825 | .Attr("T: {half, bfloat16, float, double}" ) |
826 | .SetShapeFn(shape_inference::ScalarShape); |
827 | |
828 | // -------------------------------------------------------------------------- |
829 | |
830 | REGISTER_OP("LRN" ) |
831 | .Input("input: T" ) |
832 | .Output("output: T" ) |
833 | .Attr("depth_radius: int = 5" ) |
834 | .Attr("bias: float = 1.0" ) |
835 | .Attr("alpha: float = 1.0" ) |
836 | .Attr("beta: float = 0.5" ) |
837 | .Attr("T: {half, bfloat16, float} = DT_FLOAT" ) |
838 | .SetShapeFn([](InferenceContext* c) { |
839 | return UnchangedShapeWithRank(c, 4); |
840 | }); |
841 | |
842 | REGISTER_OP("LRNGrad" ) |
843 | .Input("input_grads: T" ) |
844 | .Input("input_image: T" ) |
845 | .Input("output_image: T" ) |
846 | .Output("output: T" ) |
847 | .Attr("depth_radius: int = 5" ) |
848 | .Attr("bias: float = 1.0" ) |
849 | .Attr("alpha: float = 1.0" ) |
850 | .Attr("beta: float = 0.5" ) |
851 | .Attr("T: {half, bfloat16, float} = DT_FLOAT" ) |
852 | .SetShapeFn([](InferenceContext* c) { |
853 | ShapeHandle s; |
854 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads |
855 | TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image |
856 | TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image |
857 | c->set_output(0, s); |
858 | return OkStatus(); |
859 | }); |
860 | |
861 | // -------------------------------------------------------------------------- |
862 | |
863 | REGISTER_OP("MaxPool" ) |
864 | .Attr( |
865 | "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, " |
866 | "uint16, qint8} = DT_FLOAT" ) |
867 | .Attr("ksize: list(int) >= 4" ) |
868 | .Attr("strides: list(int) >= 4" ) |
869 | .Attr(GetPaddingAttrStringWithExplicit()) |
870 | .Attr(GetExplicitPaddingsAttrString()) |
871 | .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'" ) |
872 | .Input("input: T" ) |
873 | .Output("output: T" ) |
874 | .SetShapeFn(shape_inference::MaxPoolShapeWithExplicitPadding); |
875 | |
876 | REGISTER_OP("MaxPoolV2" ) |
877 | .Attr( |
878 | "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, " |
879 | "uint16, qint8} = DT_FLOAT" ) |
880 | .Attr(GetPaddingAttrString()) |
881 | .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'" ) |
882 | .Input("input: T" ) |
883 | .Input("ksize: int32" ) |
884 | .Input("strides: int32" ) |
885 | .Output("output: T" ) |
886 | .SetShapeFn([](InferenceContext* c) { |
887 | TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3)); |
888 | return OkStatus(); |
889 | }); |
890 | |
891 | REGISTER_OP("MaxPoolGrad" ) |
892 | .Attr("ksize: list(int) >= 4" ) |
893 | .Attr("strides: list(int) >= 4" ) |
894 | .Attr(GetPaddingAttrStringWithExplicit()) |
895 | .Attr(GetExplicitPaddingsAttrString()) |
896 | .Attr(GetConvnetDataFormatAttrString()) |
897 | .Input("orig_input: T" ) |
898 | .Input("orig_output: T" ) |
899 | .Input("grad: T" ) |
900 | .Output("output: T" ) |
901 | .Attr("T: realnumbertype = DT_FLOAT" ) |
902 | .SetShapeFn(shape_inference::MaxPoolGradShape); |
903 | |
904 | REGISTER_OP("MaxPoolGradV2" ) |
905 | .Attr(GetPaddingAttrString()) |
906 | .Attr(GetConvnetDataFormatAttrString()) |
907 | .Input("orig_input: T" ) |
908 | .Input("orig_output: T" ) |
909 | .Input("grad: T" ) |
910 | .Input("ksize: int32" ) |
911 | .Input("strides: int32" ) |
912 | .Output("output: T" ) |
913 | .Attr("T: realnumbertype = DT_FLOAT" ) |
914 | .SetShapeFn(shape_inference::MaxPoolGradShape); |
915 | |
916 | // TODO(b/150813181): Implement explicit padding. |
917 | REGISTER_OP("MaxPoolGradGrad" ) |
918 | .Attr("ksize: list(int) >= 4" ) |
919 | .Attr("strides: list(int) >= 4" ) |
920 | .Attr(GetPaddingAttrString()) |
921 | .Attr(GetConvnetDataFormatAttrString()) |
922 | .Input("orig_input: T" ) |
923 | .Input("orig_output: T" ) |
924 | .Input("grad: T" ) |
925 | .Output("output: T" ) |
926 | .Attr("T: realnumbertype" ) |
927 | .SetShapeFn([](InferenceContext* c) { |
928 | TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); |
929 | ShapeHandle unused; |
930 | // Validate 'orig_input' is the same shape as 'grad' |
931 | TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); |
932 | // Validate 'orig_output' is same shape as 'output' |
933 | TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); |
934 | return OkStatus(); |
935 | }); |
936 | |
937 | REGISTER_OP("MaxPoolGradGradV2" ) |
938 | .Attr(GetPaddingAttrString()) |
939 | .Attr(GetConvnetDataFormatAttrString()) |
940 | .Input("orig_input: T" ) |
941 | .Input("orig_output: T" ) |
942 | .Input("grad: T" ) |
943 | .Input("ksize: int32" ) |
944 | .Input("strides: int32" ) |
945 | .Output("output: T" ) |
946 | .Attr("T: realnumbertype" ) |
947 | .SetShapeFn([](InferenceContext* c) { |
948 | TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5)); |
949 | ShapeHandle unused; |
950 | // Validate 'orig_input' is the same shape as 'grad' |
951 | TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused)); |
952 | // Validate 'orig_output' is same shape as 'output' |
953 | TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused)); |
954 | return OkStatus(); |
955 | }); |
956 | |
957 | REGISTER_OP("MaxPoolWithArgmax" ) |
958 | .Attr("ksize: list(int) >= 4" ) |
959 | .Attr("strides: list(int) >= 4" ) |
960 | .Attr("Targmax: {int32, int64} = DT_INT64" ) |
961 | .Attr(GetPaddingAttrString()) |
962 | .Attr("include_batch_in_index: bool = false" ) |
963 | .Input("input: T" ) |
964 | .Output("output: T" ) |
965 | .Output("argmax: Targmax" ) |
966 | .Attr("T: realnumbertype" ) |
967 | .SetShapeFn([](InferenceContext* c) { |
968 | TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); |
969 | c->set_output(1, c->output(0)); |
970 | return OkStatus(); |
971 | }); |
972 | |
973 | REGISTER_OP("MaxPoolGradWithArgmax" ) |
974 | .Attr("ksize: list(int) >= 4" ) |
975 | .Attr("strides: list(int) >= 4" ) |
976 | .Attr(GetPaddingAttrString()) |
977 | .Attr("include_batch_in_index: bool = false" ) |
978 | .Attr("Targmax: {int32, int64}" ) |
979 | .Input("input: T" ) |
980 | .Input("grad: T" ) |
981 | .Input("argmax: Targmax" ) |
982 | .Output("output: T" ) |
983 | .Attr("T: realnumbertype" ) |
984 | .SetShapeFn([](InferenceContext* c) { |
985 | return UnchangedShapeWithRank(c, 4); |
986 | }); |
987 | |
988 | REGISTER_OP("MaxPoolGradGradWithArgmax" ) |
989 | .Attr("ksize: list(int) >= 4" ) |
990 | .Attr("strides: list(int) >= 4" ) |
991 | .Attr(GetPaddingAttrString()) |
992 | .Attr("include_batch_in_index: bool = false" ) |
993 | .Attr("Targmax: {int32, int64}" ) |
994 | .Input("input: T" ) |
995 | .Input("grad: T" ) |
996 | .Input("argmax: Targmax" ) |
997 | .Output("output: T" ) |
998 | .Attr("T: realnumbertype" ) |
999 | .SetShapeFn([](InferenceContext* c) { |
1000 | TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); |
1001 | ShapeHandle unused; |
1002 | // Validate 'orig_input' is the same shape as 'grad' |
1003 | TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused)); |
1004 | // Validate 'argmax' is same shape as 'output' |
1005 | TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused)); |
1006 | return OkStatus(); |
1007 | }); |
1008 | |
1009 | // -------------------------------------------------------------------------- |
1010 | |
1011 | REGISTER_OP("Dilation2D" ) |
1012 | .Input("input: T" ) |
1013 | .Input("filter: T" ) |
1014 | .Output("output: T" ) |
1015 | .Attr("T: realnumbertype" ) |
1016 | .Attr("strides: list(int) >= 4" ) |
1017 | .Attr("rates: list(int) >= 4" ) |
1018 | .Attr(GetPaddingAttrString()) |
1019 | .SetShapeFn([](InferenceContext* c) { |
1020 | ShapeHandle input_shape; |
1021 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
1022 | ShapeHandle filter_shape; |
1023 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &filter_shape)); |
1024 | |
1025 | std::vector<int32> strides; |
1026 | TF_RETURN_IF_ERROR(c->GetAttr("strides" , &strides)); |
1027 | if (strides.size() != 4) { |
1028 | return errors::InvalidArgument( |
1029 | "Dilation2D requires the stride attribute to contain 4 values, but " |
1030 | "got: " , |
1031 | strides.size()); |
1032 | } |
1033 | |
1034 | std::vector<int32> rates; |
1035 | TF_RETURN_IF_ERROR(c->GetAttr("rates" , &rates)); |
1036 | if (rates.size() != 4) { |
1037 | return errors::InvalidArgument( |
1038 | "Dilation2D requires the rates attribute to contain 4 values, but " |
1039 | "got: " , |
1040 | rates.size()); |
1041 | } |
1042 | |
1043 | int32_t stride_rows = strides[1]; |
1044 | int32_t stride_cols = strides[2]; |
1045 | |
1046 | int32_t rate_rows = rates[1]; |
1047 | int32_t rate_cols = rates[2]; |
1048 | |
1049 | DimensionHandle batch_size_dim = c->Dim(input_shape, 0); |
1050 | DimensionHandle in_rows_dim = c->Dim(input_shape, 1); |
1051 | DimensionHandle in_cols_dim = c->Dim(input_shape, 2); |
1052 | DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0); |
1053 | DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1); |
1054 | DimensionHandle output_depth_dim = c->Dim(filter_shape, 2); |
1055 | |
1056 | if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim) || |
1057 | !c->ValueKnown(filter_rows_dim) || !c->ValueKnown(filter_cols_dim)) { |
1058 | ShapeHandle output_shape = |
1059 | c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, |
1060 | InferenceContext::kUnknownDim, output_depth_dim}); |
1061 | c->set_output(0, output_shape); |
1062 | return OkStatus(); |
1063 | } |
1064 | DimensionHandle unused; |
1065 | TF_RETURN_IF_ERROR( |
1066 | c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused)); |
1067 | |
1068 | auto in_rows = c->Value(in_rows_dim); |
1069 | auto in_cols = c->Value(in_cols_dim); |
1070 | auto filter_rows = c->Value(filter_rows_dim); |
1071 | auto filter_cols = c->Value(filter_cols_dim); |
1072 | auto filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_rows - 1); |
1073 | auto filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_cols - 1); |
1074 | |
1075 | Padding padding; |
1076 | TF_RETURN_IF_ERROR(c->GetAttr("padding" , &padding)); |
1077 | |
1078 | int64_t output_rows, output_cols; |
1079 | int64_t padding_before, padding_after; |
1080 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( |
1081 | in_rows, filter_rows_eff, stride_rows, padding, &output_rows, |
1082 | &padding_before, &padding_after)); |
1083 | TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( |
1084 | in_cols, filter_cols_eff, stride_cols, padding, &output_cols, |
1085 | &padding_before, &padding_after)); |
1086 | |
1087 | ShapeHandle output_shape = c->MakeShape( |
1088 | {batch_size_dim, output_rows, output_cols, output_depth_dim}); |
1089 | c->set_output(0, output_shape); |
1090 | return OkStatus(); |
1091 | }); |
1092 | |
1093 | REGISTER_OP("Dilation2DBackpropInput" ) |
1094 | .Input("input: T" ) |
1095 | .Input("filter: T" ) |
1096 | .Input("out_backprop: T" ) |
1097 | .Output("in_backprop: T" ) |
1098 | .Attr("T: realnumbertype" ) |
1099 | .Attr("strides: list(int) >= 4" ) |
1100 | .Attr("rates: list(int) >= 4" ) |
1101 | .Attr(GetPaddingAttrString()) |
1102 | .SetShapeFn(shape_inference::UnchangedShape); |
1103 | |
1104 | REGISTER_OP("Dilation2DBackpropFilter" ) |
1105 | .Input("input: T" ) |
1106 | .Input("filter: T" ) |
1107 | .Input("out_backprop: T" ) |
1108 | .Output("filter_backprop: T" ) |
1109 | .Attr("T: realnumbertype" ) |
1110 | .Attr("strides: list(int) >= 4" ) |
1111 | .Attr("rates: list(int) >= 4" ) |
1112 | .Attr(GetPaddingAttrString()) |
1113 | .SetShapeFn([](InferenceContext* c) { |
1114 | c->set_output(0, c->input(1)); |
1115 | return OkStatus(); |
1116 | }); |
1117 | |
1118 | // -------------------------------------------------------------------------- |
1119 | |
1120 | REGISTER_OP("Relu" ) |
1121 | .Input("features: T" ) |
1122 | .Output("activations: T" ) |
1123 | .Attr("T: {realnumbertype, qint8}" ) |
1124 | .SetShapeFn(shape_inference::UnchangedShape); |
1125 | |
1126 | REGISTER_OP("ReluGrad" ) |
1127 | .Input("gradients: T" ) |
1128 | .Input("features: T" ) |
1129 | .Output("backprops: T" ) |
1130 | .Attr("T: realnumbertype" ) |
1131 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn); |
1132 | |
1133 | REGISTER_OP("Relu6" ) |
1134 | .Input("features: T" ) |
1135 | .Output("activations: T" ) |
1136 | .Attr("T: realnumbertype" ) |
1137 | .SetShapeFn(shape_inference::UnchangedShape); |
1138 | |
1139 | REGISTER_OP("Relu6Grad" ) |
1140 | .Input("gradients: T" ) |
1141 | .Input("features: T" ) |
1142 | .Output("backprops: T" ) |
1143 | .Attr("T: realnumbertype" ) |
1144 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn); |
1145 | |
1146 | REGISTER_OP("LeakyRelu" ) |
1147 | .Input("features: T" ) |
1148 | .Output("activations: T" ) |
1149 | .Attr("alpha: float = 0.2" ) |
1150 | .Attr("T: {half, bfloat16, float, double} = DT_FLOAT" ) |
1151 | .SetShapeFn(shape_inference::UnchangedShape); |
1152 | |
1153 | REGISTER_OP("LeakyReluGrad" ) |
1154 | .Input("gradients: T" ) |
1155 | .Input("features: T" ) |
1156 | .Output("backprops: T" ) |
1157 | .Attr("alpha: float = 0.2" ) |
1158 | .Attr("T: {half, bfloat16, float, double} = DT_FLOAT" ) |
1159 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn); |
1160 | |
1161 | REGISTER_OP("Elu" ) |
1162 | .Input("features: T" ) |
1163 | .Output("activations: T" ) |
1164 | .Attr("T: {half, bfloat16, float, double}" ) |
1165 | .SetShapeFn(shape_inference::UnchangedShape); |
1166 | |
1167 | REGISTER_OP("EluGrad" ) |
1168 | .Input("gradients: T" ) |
1169 | .Input("outputs: T" ) |
1170 | .Output("backprops: T" ) |
1171 | .Attr("T: {half, bfloat16, float, double}" ) |
1172 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn); |
1173 | |
1174 | REGISTER_OP("Selu" ) |
1175 | .Input("features: T" ) |
1176 | .Output("activations: T" ) |
1177 | .Attr("T: {half, bfloat16, float, double}" ) |
1178 | .SetShapeFn(shape_inference::UnchangedShape); |
1179 | |
1180 | REGISTER_OP("SeluGrad" ) |
1181 | .Input("gradients: T" ) |
1182 | .Input("outputs: T" ) |
1183 | .Output("backprops: T" ) |
1184 | .Attr("T: {half, bfloat16, float, double}" ) |
1185 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn); |
1186 | |
1187 | REGISTER_OP("Softplus" ) |
1188 | .Input("features: T" ) |
1189 | .Output("activations: T" ) |
1190 | .Attr("T: {half, bfloat16, float, double}" ) |
1191 | .SetShapeFn(shape_inference::UnchangedShape); |
1192 | |
1193 | REGISTER_OP("SoftplusGrad" ) |
1194 | .Input("gradients: T" ) |
1195 | .Input("features: T" ) |
1196 | .Output("backprops: T" ) |
1197 | .Attr("T: {half, bfloat16, float, double}" ) |
1198 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn); |
1199 | |
1200 | REGISTER_OP("Softsign" ) |
1201 | .Input("features: T" ) |
1202 | .Output("activations: T" ) |
1203 | .Attr("T: {half, bfloat16, float, double}" ) |
1204 | .SetShapeFn(shape_inference::UnchangedShape); |
1205 | |
1206 | REGISTER_OP("SoftsignGrad" ) |
1207 | .Input("gradients: T" ) |
1208 | .Input("features: T" ) |
1209 | .Output("backprops: T" ) |
1210 | .Attr("T: {half, bfloat16, float, double}" ) |
1211 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn); |
1212 | |
1213 | // -------------------------------------------------------------------------- |
1214 | |
1215 | REGISTER_OP("Softmax" ) |
1216 | .Input("logits: T" ) |
1217 | .Output("softmax: T" ) |
1218 | .Attr("T: {half, bfloat16, float, double}" ) |
1219 | .SetShapeFn([](InferenceContext* c) { |
1220 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); |
1221 | }); |
1222 | |
1223 | // -------------------------------------------------------------------------- |
1224 | |
1225 | REGISTER_OP("LogSoftmax" ) |
1226 | .Input("logits: T" ) |
1227 | .Output("logsoftmax: T" ) |
1228 | .Attr("T: {half, bfloat16, float, double}" ) |
1229 | .SetShapeFn([](InferenceContext* c) { |
1230 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); |
1231 | }); |
1232 | |
1233 | // -------------------------------------------------------------------------- |
1234 | |
1235 | REGISTER_OP("SoftmaxCrossEntropyWithLogits" ) |
1236 | .Input("features: T" ) |
1237 | .Input("labels: T" ) |
1238 | .Output("loss: T" ) |
1239 | .Output("backprop: T" ) |
1240 | .Attr("T: {half, bfloat16, float, double}" ) |
1241 | .SetShapeFn([](InferenceContext* c) { |
1242 | ShapeHandle input; |
1243 | if (c->WithRank(c->input(0), 2, &input) == OkStatus() && |
1244 | c->Merge(input, c->input(1), &input) == OkStatus()) { |
1245 | DimensionHandle batch_size = c->Dim(input, 0); |
1246 | c->set_output(0, c->Vector(batch_size)); |
1247 | c->set_output(1, input); |
1248 | return OkStatus(); |
1249 | } |
1250 | TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFn(c, 1)); |
1251 | |
1252 | if (!c->RankKnown(c->output(1))) { |
1253 | return errors::InvalidArgument( |
1254 | "Shape must be broadcasted with rank 2, but is rank is unknown." ); |
1255 | } |
1256 | |
1257 | if (c->Rank(c->output(1)) != 2) { |
1258 | return errors::InvalidArgument( |
1259 | "Shape must be broadcasted with rank 2, but is rank " , |
1260 | c->Rank(c->output(1))); |
1261 | } |
1262 | DimensionHandle batch_size = c->Dim(c->output(1), 0); |
1263 | c->set_output(0, c->Vector(batch_size)); |
1264 | return OkStatus(); |
1265 | }); |
1266 | |
1267 | REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits" ) |
1268 | .Input("features: T" ) |
1269 | .Input("labels: Tlabels" ) |
1270 | .Output("loss: T" ) |
1271 | .Output("backprop: T" ) |
1272 | .Attr("T: {half, bfloat16, float, double}" ) |
1273 | .Attr("Tlabels: {int32, int64} = DT_INT64" ) |
1274 | .SetShapeFn([](InferenceContext* c) { |
1275 | ShapeHandle features; |
1276 | ShapeHandle labels; |
1277 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features)); |
1278 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels)); |
1279 | |
1280 | DimensionHandle batch_size; |
1281 | TF_RETURN_IF_ERROR( |
1282 | c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size)); |
1283 | TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features)); |
1284 | |
1285 | c->set_output(0, c->Vector(batch_size)); |
1286 | c->set_output(1, features); |
1287 | return OkStatus(); |
1288 | }); |
1289 | |
1290 | // -------------------------------------------------------------------------- |
1291 | |
1292 | REGISTER_OP("InTopK" ) |
1293 | .Input("predictions: float" ) |
1294 | .Input("targets: T" ) |
1295 | .Output("precision: bool" ) |
1296 | .Attr("k: int" ) |
1297 | .Attr("T: {int32, int64} = DT_INT32" ) |
1298 | .SetShapeFn([](InferenceContext* c) { |
1299 | ShapeHandle predictions; |
1300 | ShapeHandle targets; |
1301 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions)); |
1302 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets)); |
1303 | DimensionHandle batch_size; |
1304 | TF_RETURN_IF_ERROR( |
1305 | c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size)); |
1306 | c->set_output(0, c->Vector(batch_size)); |
1307 | return OkStatus(); |
1308 | }); |
1309 | |
1310 | // This is the same as `InTopK`, but takes `k` as in input rather than an attr. |
1311 | REGISTER_OP("InTopKV2" ) |
1312 | .Input("predictions: float" ) |
1313 | .Input("targets: T" ) |
1314 | .Input("k: T" ) |
1315 | .Output("precision: bool" ) |
1316 | .Attr("T: {int32, int64} = DT_INT32" ) |
1317 | .SetShapeFn([](InferenceContext* c) { |
1318 | ShapeHandle predictions; |
1319 | ShapeHandle targets; |
1320 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions)); |
1321 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets)); |
1322 | DimensionHandle batch_size; |
1323 | TF_RETURN_IF_ERROR( |
1324 | c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size)); |
1325 | c->set_output(0, c->Vector(batch_size)); |
1326 | return OkStatus(); |
1327 | }); |
1328 | |
1329 | namespace { |
1330 | |
1331 | Status TopKShapeFn(InferenceContext* c) { |
1332 | ShapeHandle input; |
1333 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); |
1334 | |
1335 | // Get the k value, either from input tensor or attribute. |
1336 | DimensionHandle k_dim; |
1337 | if (c->num_inputs() >= 2) { |
1338 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &k_dim)); |
1339 | } else { |
1340 | int32_t k; |
1341 | TF_RETURN_IF_ERROR(c->GetAttr("k" , &k)); |
1342 | if (k < 0) { |
1343 | return errors::InvalidArgument("Need k >= 0, got " , k); |
1344 | } |
1345 | k_dim = c->MakeDim(k); |
1346 | } |
1347 | |
1348 | DimensionHandle last_dim = c->Dim(input, -1); |
1349 | if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) && |
1350 | c->Value(last_dim) < c->Value(k_dim)) { |
1351 | return errors::InvalidArgument( |
1352 | "input must have last dimension >= k = " , c->Value(k_dim), " but is " , |
1353 | c->Value(last_dim)); |
1354 | } |
1355 | |
1356 | // Replace last_dim with k_dim. |
1357 | ShapeHandle s; |
1358 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s)); |
1359 | TF_RETURN_IF_ERROR(c->Concatenate(s, c->Vector(k_dim), &s)); |
1360 | c->set_output(0, s); |
1361 | c->set_output(1, s); |
1362 | return OkStatus(); |
1363 | } |
1364 | |
1365 | // Utility functions for ApproxTopKShape. |
1366 | // It is not easy to link xla/client/lib into the tensorflow core lib, so we |
1367 | // have to replicate the logic. |
1368 | // LINT.IfChange |
1369 | inline uint32_t log2_floor(uint64_t value) { |
1370 | return value == 0 ? 0 : Log2Floor(value); |
1371 | } |
1372 | |
1373 | inline uint32_t log2_ceil(uint64_t value) { |
1374 | return value == 0 ? 0 : Log2Ceiling(value); |
1375 | } |
1376 | |
1377 | Status ApproxTopKShape(shape_inference::InferenceContext* c) { |
1378 | int64_t k; |
1379 | int64_t reduction_dimension; |
1380 | float recall_target; |
1381 | int64_t reduction_input_size_override; |
1382 | bool aggregate_to_topk; |
1383 | TF_RETURN_IF_ERROR(c->GetAttr("k" , &k)); |
1384 | TF_RETURN_IF_ERROR(c->GetAttr("reduction_dimension" , &reduction_dimension)); |
1385 | TF_RETURN_IF_ERROR(c->GetAttr("recall_target" , &recall_target)); |
1386 | TF_RETURN_IF_ERROR(c->GetAttr("reduction_input_size_override" , |
1387 | &reduction_input_size_override)); |
1388 | TF_RETURN_IF_ERROR(c->GetAttr("aggregate_to_topk" , &aggregate_to_topk)); |
1389 | ShapeHandle input_shape; |
1390 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input_shape)); |
1391 | if (reduction_dimension < 0) { |
1392 | // Reverse index |
1393 | reduction_dimension += c->Rank(input_shape); |
1394 | } |
1395 | int64_t reduction_dim_value = |
1396 | c->Value(c->Dim(input_shape, reduction_dimension)); |
1397 | |
1398 | if (reduction_dim_value < k) { |
1399 | return errors::InvalidArgument("input must have last dimension >= k = " , k, |
1400 | " but was " , reduction_dim_value); |
1401 | } |
1402 | |
1403 | int64_t output_dim_value = [&] { |
1404 | if (aggregate_to_topk) { |
1405 | return k; |
1406 | } |
1407 | int64_t tpu_tiling = c->Rank(input_shape) == 1 ? 1024 : 128; |
1408 | if (reduction_dim_value <= tpu_tiling || recall_target == 1.0) { |
1409 | return reduction_dim_value; |
1410 | } |
1411 | if (k == 1) { |
1412 | return tpu_tiling; |
1413 | } |
1414 | uint64_t logical_input_size = reduction_input_size_override >= 0 |
1415 | ? reduction_input_size_override |
1416 | : reduction_dim_value; |
1417 | uint64_t m = std::min<uint64_t>( |
1418 | std::max<uint64_t>( |
1419 | static_cast<uint64_t>((1.0 - k) / |
1420 | std::log(static_cast<double>(recall_target))), |
1421 | tpu_tiling), |
1422 | reduction_dim_value); |
1423 | uint32_t log2_reduction = log2_floor(logical_input_size / m); |
1424 | if (log2_reduction == 0) { |
1425 | return reduction_dim_value; |
1426 | } |
1427 | log2_reduction = std::min<uint32_t>( |
1428 | log2_reduction, log2_ceil(reduction_dim_value / tpu_tiling)); |
1429 | return tensorflow::MathUtil::CeilOfRatio<int64_t>( |
1430 | tensorflow::MathUtil::CeilOfRatio<int64_t>(reduction_dim_value, |
1431 | tpu_tiling), |
1432 | (1 << log2_reduction)) * |
1433 | tpu_tiling; |
1434 | }(); |
1435 | |
1436 | auto output_dim = c->MakeDim(output_dim_value); |
1437 | |
1438 | ShapeHandle output_shape; |
1439 | TF_RETURN_IF_ERROR(c->ReplaceDim(input_shape, reduction_dimension, output_dim, |
1440 | &output_shape)); |
1441 | c->set_output(0, output_shape); |
1442 | c->set_output(1, output_shape); |
1443 | return OkStatus(); |
1444 | } |
1445 | // LINT.ThenChange(//tensorflow/compiler/xla/client/lib/approx_topk_shape.cc) |
1446 | |
1447 | } // namespace |
1448 | |
1449 | REGISTER_OP("TopK" ) |
1450 | .Input("input: T" ) |
1451 | .Output("values: T" ) |
1452 | .Output("indices: int32" ) |
1453 | .Attr("k: int >= 0" ) |
1454 | .Attr("sorted: bool = true" ) |
1455 | .Attr("T: realnumbertype" ) |
1456 | .Deprecated(7, "Use TopKV2 instead" ) |
1457 | .SetShapeFn(TopKShapeFn); |
1458 | |
1459 | // This is the same as `TopK`, but takes `k` as in input rather than an attr. |
1460 | REGISTER_OP("TopKV2" ) |
1461 | .Input("input: T" ) |
1462 | .Input("k: int32" ) |
1463 | .Output("values: T" ) |
1464 | .Output("indices: int32" ) |
1465 | .Attr("sorted: bool = true" ) |
1466 | .Attr("T: realnumbertype" ) |
1467 | .SetShapeFn(TopKShapeFn); |
1468 | |
1469 | REGISTER_OP("ApproxTopK" ) |
1470 | .Input("input: T" ) |
1471 | .Output("values: T" ) |
1472 | .Output("indices: int32" ) |
1473 | .Attr("k: int >= 0" ) |
1474 | .Attr("reduction_dimension: int = -1" ) |
1475 | .Attr("recall_target: float = 0.95" ) |
1476 | .Attr("is_max_k: bool = true" ) |
1477 | .Attr("reduction_input_size_override: int = -1" ) |
1478 | .Attr("aggregate_to_topk: bool = true" ) |
1479 | .Attr("T: {half, bfloat16, float}" ) |
1480 | .SetShapeFn(ApproxTopKShape); |
1481 | |
1482 | // -------------------------------------------------------------------------- |
1483 | |
1484 | REGISTER_OP("NthElement" ) |
1485 | .Input("input: T" ) |
1486 | .Input("n: int32" ) |
1487 | .Output("values: T" ) |
1488 | .Attr("reverse: bool = false" ) |
1489 | .Attr("T: realnumbertype" ) |
1490 | .SetShapeFn([](InferenceContext* c) { |
1491 | ShapeHandle input; |
1492 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input)); |
1493 | |
1494 | // Get the n value from input tensor, and make sure which is a scalar. |
1495 | DimensionHandle n_dim; |
1496 | TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &n_dim)); |
1497 | |
1498 | // The last dimension of input tensor must be greater than N. |
1499 | DimensionHandle last_dim = c->Dim(input, -1); |
1500 | if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) && |
1501 | c->Value(last_dim) <= c->Value(n_dim)) { |
1502 | return errors::InvalidArgument( |
1503 | "Input must have last dimension > n = " , c->Value(n_dim), |
1504 | " but is " , c->Value(last_dim)); |
1505 | } |
1506 | |
1507 | // Reduce last_dim for output tensor |
1508 | ShapeHandle s; |
1509 | TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s)); |
1510 | c->set_output(0, s); |
1511 | return OkStatus(); |
1512 | }); |
1513 | |
1514 | // -------------------------------------------------------------------------- |
1515 | |
1516 | REGISTER_OP("FractionalMaxPool" ) |
1517 | .Input("value: T" ) |
1518 | .Output("output: T" ) |
1519 | .Output("row_pooling_sequence: int64" ) |
1520 | .Output("col_pooling_sequence: int64" ) |
1521 | .Attr("pooling_ratio: list(float) >=4" ) |
1522 | .Attr("pseudo_random: bool = false" ) |
1523 | .Attr("overlapping: bool = false" ) |
1524 | .Attr("deterministic: bool = false" ) |
1525 | .Attr("seed: int = 0" ) |
1526 | .Attr("seed2: int = 0" ) |
1527 | .Attr("T: {float, double, int32, int64}" ) |
1528 | .SetShapeFn(FractionalPoolShapeFn); |
1529 | |
1530 | REGISTER_OP("FractionalMaxPoolGrad" ) |
1531 | .Input("orig_input: T" ) |
1532 | .Input("orig_output: T" ) |
1533 | .Input("out_backprop: T" ) |
1534 | .Input("row_pooling_sequence: int64" ) |
1535 | .Input("col_pooling_sequence: int64" ) |
1536 | .Output("output: T" ) |
1537 | .Attr("overlapping: bool = false" ) |
1538 | .Attr("T: {float, double, int32, int64}" ) |
1539 | .SetShapeFn([](InferenceContext* c) { |
1540 | return shape_inference::UnchangedShapeWithRank(c, 4); |
1541 | }); |
1542 | |
1543 | // -------------------------------------------------------------------------- |
1544 | |
1545 | REGISTER_OP("FractionalAvgPool" ) |
1546 | .Input("value: T" ) |
1547 | .Output("output: T" ) |
1548 | .Output("row_pooling_sequence: int64" ) |
1549 | .Output("col_pooling_sequence: int64" ) |
1550 | .Attr("pooling_ratio: list(float) >=4" ) |
1551 | .Attr("pseudo_random: bool = false" ) |
1552 | .Attr("overlapping: bool = false" ) |
1553 | .Attr("deterministic: bool = false" ) |
1554 | .Attr("seed: int = 0" ) |
1555 | .Attr("seed2: int = 0" ) |
1556 | .Attr("T: {float, double, int32, int64}" ) |
1557 | .SetShapeFn(FractionalPoolShapeFn); |
1558 | |
1559 | REGISTER_OP("FractionalAvgPoolGrad" ) |
1560 | .Input("orig_input_tensor_shape: int64" ) |
1561 | .Input("out_backprop: T" ) |
1562 | .Input("row_pooling_sequence: int64" ) |
1563 | .Input("col_pooling_sequence: int64" ) |
1564 | .Output("output: T" ) |
1565 | .Attr("overlapping: bool = false" ) |
1566 | .Attr("T: {float, double, int32, int64}" ) |
1567 | .SetShapeFn([](InferenceContext* c) { |
1568 | if (c->input_tensor(0) != nullptr) { |
1569 | ShapeHandle out; |
1570 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); |
1571 | c->set_output(0, out); |
1572 | } else { |
1573 | c->set_output(0, c->UnknownShapeOfRank(4)); |
1574 | } |
1575 | return OkStatus(); |
1576 | }); |
1577 | |
1578 | REGISTER_OP("QuantizedAvgPool" ) |
1579 | .Input("input: T" ) |
1580 | .Input("min_input: float" ) |
1581 | .Input("max_input: float" ) |
1582 | .Output("output: T" ) |
1583 | .Output("min_output: float" ) |
1584 | .Output("max_output: float" ) |
1585 | .Attr("T: quantizedtype" ) |
1586 | .Attr("ksize: list(int)" ) |
1587 | .Attr("strides: list(int)" ) |
1588 | .Attr(GetPaddingAttrString()) |
1589 | .SetShapeFn(shape_inference::QuantizedAvgPoolShape); |
1590 | |
1591 | REGISTER_OP("QuantizedBiasAdd" ) |
1592 | .Input("input: T1" ) |
1593 | .Input("bias: T2" ) |
1594 | .Input("min_input: float" ) |
1595 | .Input("max_input: float" ) |
1596 | .Input("min_bias: float" ) |
1597 | .Input("max_bias: float" ) |
1598 | .Output("output: out_type" ) |
1599 | .Output("min_out: float" ) |
1600 | .Output("max_out: float" ) |
1601 | .Attr("T1: quantizedtype" ) |
1602 | .Attr("T2: quantizedtype" ) |
1603 | .Attr("out_type: quantizedtype" ) |
1604 | .SetShapeFn([](InferenceContext* c) { |
1605 | TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c)); |
1606 | ShapeHandle unused; |
1607 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1608 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1609 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
1610 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
1611 | c->set_output(1, c->Scalar()); |
1612 | c->set_output(2, c->Scalar()); |
1613 | return OkStatus(); |
1614 | }); |
1615 | |
1616 | REGISTER_OP("QuantizedConv2D" ) |
1617 | .Input("input: Tinput" ) |
1618 | .Input("filter: Tfilter" ) |
1619 | .Input("min_input: float" ) |
1620 | .Input("max_input: float" ) |
1621 | .Input("min_filter: float" ) |
1622 | .Input("max_filter: float" ) |
1623 | .Output("output: out_type" ) |
1624 | .Output("min_output: float" ) |
1625 | .Output("max_output: float" ) |
1626 | .Attr("Tinput: quantizedtype" ) |
1627 | .Attr("Tfilter: quantizedtype" ) |
1628 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
1629 | .Attr("strides: list(int)" ) |
1630 | .Attr(GetPaddingAttrString()) |
1631 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1632 | .SetShapeFn(shape_inference::QuantizedConv2DShape); |
1633 | |
1634 | REGISTER_OP("QuantizedMaxPool" ) |
1635 | .Input("input: T" ) |
1636 | .Input("min_input: float" ) |
1637 | .Input("max_input: float" ) |
1638 | .Output("output: T" ) |
1639 | .Output("min_output: float" ) |
1640 | .Output("max_output: float" ) |
1641 | .Attr("T: quantizedtype" ) |
1642 | .Attr("ksize: list(int)" ) |
1643 | .Attr("strides: list(int)" ) |
1644 | .Attr(GetPaddingAttrString()) |
1645 | .SetShapeFn([](InferenceContext* c) { |
1646 | TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c)); |
1647 | ShapeHandle unused; |
1648 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1649 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1650 | c->set_output(1, c->Scalar()); |
1651 | c->set_output(2, c->Scalar()); |
1652 | return OkStatus(); |
1653 | }); |
1654 | |
1655 | REGISTER_OP("QuantizedRelu" ) |
1656 | .Input("features: Tinput" ) |
1657 | .Input("min_features: float" ) |
1658 | .Input("max_features: float" ) |
1659 | .Output("activations: out_type" ) |
1660 | .Output("min_activations: float" ) |
1661 | .Output("max_activations: float" ) |
1662 | .Attr("Tinput: quantizedtype" ) |
1663 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
1664 | .SetShapeFn([](InferenceContext* c) { |
1665 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
1666 | ShapeHandle unused; |
1667 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1668 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1669 | c->set_output(1, c->Scalar()); |
1670 | c->set_output(2, c->Scalar()); |
1671 | return OkStatus(); |
1672 | }); |
1673 | |
1674 | REGISTER_OP("QuantizedRelu6" ) |
1675 | .Input("features: Tinput" ) |
1676 | .Input("min_features: float" ) |
1677 | .Input("max_features: float" ) |
1678 | .Output("activations: out_type" ) |
1679 | .Output("min_activations: float" ) |
1680 | .Output("max_activations: float" ) |
1681 | .Attr("Tinput: quantizedtype" ) |
1682 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
1683 | .SetShapeFn([](InferenceContext* c) { |
1684 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
1685 | ShapeHandle unused; |
1686 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1687 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1688 | c->set_output(1, c->Scalar()); |
1689 | c->set_output(2, c->Scalar()); |
1690 | return OkStatus(); |
1691 | }); |
1692 | |
1693 | REGISTER_OP("QuantizedReluX" ) |
1694 | .Input("features: Tinput" ) |
1695 | .Input("max_value: float" ) |
1696 | .Input("min_features: float" ) |
1697 | .Input("max_features: float" ) |
1698 | .Output("activations: out_type" ) |
1699 | .Output("min_activations: float" ) |
1700 | .Output("max_activations: float" ) |
1701 | .Attr("Tinput: quantizedtype" ) |
1702 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
1703 | .SetShapeFn([](InferenceContext* c) { |
1704 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
1705 | ShapeHandle unused; |
1706 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1707 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1708 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1709 | c->set_output(1, c->Scalar()); |
1710 | c->set_output(2, c->Scalar()); |
1711 | return OkStatus(); |
1712 | }); |
1713 | |
1714 | REGISTER_OP("QuantizedBatchNormWithGlobalNormalization" ) |
1715 | .Input("t: Tinput" ) |
1716 | .Input("t_min: float" ) |
1717 | .Input("t_max: float" ) |
1718 | .Input("m: Tinput" ) |
1719 | .Input("m_min: float" ) |
1720 | .Input("m_max: float" ) |
1721 | .Input("v: Tinput" ) |
1722 | .Input("v_min: float" ) |
1723 | .Input("v_max: float" ) |
1724 | .Input("beta: Tinput" ) |
1725 | .Input("beta_min: float" ) |
1726 | .Input("beta_max: float" ) |
1727 | .Input("gamma: Tinput" ) |
1728 | .Input("gamma_min: float" ) |
1729 | .Input("gamma_max: float" ) |
1730 | .Output("result: out_type" ) |
1731 | .Output("result_min: float" ) |
1732 | .Output("result_max: float" ) |
1733 | .Attr("Tinput: quantizedtype" ) |
1734 | .Attr("out_type: quantizedtype" ) |
1735 | .Attr("variance_epsilon: float" ) |
1736 | .Attr("scale_after_normalization: bool" ) |
1737 | .SetShapeFn([](InferenceContext* c) { |
1738 | ShapeHandle input; |
1739 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input)); |
1740 | |
1741 | DimensionHandle last_dim = c->Dim(input, 3); |
1742 | for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma |
1743 | ShapeHandle vec; |
1744 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec)); |
1745 | TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim)); |
1746 | } |
1747 | |
1748 | ShapeHandle out; |
1749 | TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out)); |
1750 | c->set_output(0, out); |
1751 | c->set_output(1, c->Scalar()); |
1752 | c->set_output(2, c->Scalar()); |
1753 | |
1754 | return OkStatus(); |
1755 | }); |
1756 | |
1757 | #ifdef INTEL_MKL |
1758 | REGISTER_OP("_MklDepthwiseConv2dNative" ) |
1759 | .Input("input: T" ) |
1760 | .Input("filter: T" ) |
1761 | .Input("mkl_input: uint8" ) |
1762 | .Input("mkl_filter: uint8" ) |
1763 | .Output("output: T" ) |
1764 | .Output("filter_output: T" ) |
1765 | .Output("mkl_output: uint8" ) |
1766 | .Output("mkl_filter_output: uint8" ) |
1767 | .Attr("T: {half, bfloat16, float, double}" ) |
1768 | .Attr("strides: list(int)" ) |
1769 | .Attr("is_filter_const: bool = false" ) |
1770 | .Attr(GetPaddingAttrStringWithExplicit()) |
1771 | .Attr(GetConvnetDataFormatAttrString()) |
1772 | .Attr(GetExplicitPaddingsAttrString()) |
1773 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1774 | .SetShapeFn(shape_inference::DepthwiseConv2DNativeShapeWithExplicitPadding); |
1775 | |
1776 | REGISTER_OP("_MklConv2D" ) |
1777 | .Input("input: T" ) |
1778 | .Input("filter: T" ) |
1779 | .Input("mkl_input: uint8" ) |
1780 | .Input("mkl_filter: uint8" ) |
1781 | .Output("output: T" ) |
1782 | .Output("filter_output: T" ) |
1783 | .Output("mkl_output: uint8" ) |
1784 | .Output("mkl_filter_output: uint8" ) |
1785 | .Attr("T: {bfloat16, float}" ) |
1786 | .Attr("strides: list(int)" ) |
1787 | .Attr("use_cudnn_on_gpu: bool = true" ) |
1788 | .Attr("is_filter_const: bool = false" ) |
1789 | .Attr(GetPaddingAttrStringWithExplicit()) |
1790 | .Attr(GetConvnetDataFormatAttrString()) |
1791 | .Attr(GetExplicitPaddingsAttrString()) |
1792 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1793 | .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding) |
1794 | .Doc(R"doc( |
1795 | MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution. |
1796 | |
1797 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
1798 | expected to invoke these operators. |
1799 | )doc" ); |
1800 | |
1801 | REGISTER_OP("_MklNativeConv2D" ) |
1802 | .Input("input: T" ) |
1803 | .Input("filter: T" ) |
1804 | .Output("output: T" ) |
1805 | .Attr("T: {bfloat16, float}" ) |
1806 | .Attr("strides: list(int)" ) |
1807 | .Attr("use_cudnn_on_gpu: bool = true" ) |
1808 | .Attr("is_filter_const: bool = false" ) |
1809 | .Attr(GetPaddingAttrStringWithExplicit()) |
1810 | .Attr(GetExplicitPaddingsAttrString()) |
1811 | .Attr(GetConvnetDataFormatAttrString()) |
1812 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1813 | .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding) |
1814 | .Doc(R"doc( |
1815 | MKL version of Conv2D operator for Eager mode. Uses MKL DNN APIs to perform 2D convolution. |
1816 | |
1817 | NOTE Do not invoke this operator directly in Python. Eager Op rewrite is |
1818 | expected to invoke these operators. |
1819 | )doc" ); |
1820 | |
1821 | REGISTER_OP("__MklDummyConv2DWithBias" ) |
1822 | .Input("input: T" ) |
1823 | .Input("filter: T" ) |
1824 | .Input("bias: T" ) |
1825 | .Output("output: T" ) |
1826 | .Attr("T: {bfloat16, float}" ) |
1827 | .Attr("strides: list(int)" ) |
1828 | .Attr("use_cudnn_on_gpu: bool = true" ) |
1829 | .Attr("is_filter_const: bool = false" ) |
1830 | .Attr(GetPaddingAttrStringWithExplicit()) |
1831 | .Attr(GetExplicitPaddingsAttrString()) |
1832 | .Attr(GetConvnetDataFormatAttrString()) |
1833 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1834 | .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding) |
1835 | .Doc(R"doc( |
1836 | Dummy node that enables fusing Conv2D and BiasAdd operator for MKL. This node |
1837 | does not perform anything. It is just created as an intermediate output of |
1838 | merging Conv2D and BiasAdd. |
1839 | |
1840 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
1841 | expected to invoke these operators. |
1842 | )doc" ); |
1843 | |
1844 | REGISTER_OP("_MklConv2DWithBias" ) |
1845 | .Input("input: T" ) |
1846 | .Input("filter: T" ) |
1847 | .Input("bias: T" ) |
1848 | .Input("mkl_input: uint8" ) |
1849 | .Input("mkl_filter: uint8" ) |
1850 | .Input("mkl_bias: uint8" ) |
1851 | .Output("output: T" ) |
1852 | .Output("filter_output: T" ) |
1853 | .Output("mkl_output: uint8" ) |
1854 | .Output("mkl_filter_output: uint8" ) |
1855 | .Attr("T: {bfloat16, float}" ) |
1856 | .Attr("strides: list(int)" ) |
1857 | .Attr("use_cudnn_on_gpu: bool = true" ) |
1858 | .Attr("is_filter_const: bool = false" ) |
1859 | .Attr(GetPaddingAttrStringWithExplicit()) |
1860 | .Attr(GetExplicitPaddingsAttrString()) |
1861 | .Attr(GetConvnetDataFormatAttrString()) |
1862 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1863 | .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding) |
1864 | .Doc(R"doc( |
1865 | MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform |
1866 | 2D convolution and add Bias to the output of convolution. |
1867 | |
1868 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
1869 | expected to invoke these operators. |
1870 | )doc" ); |
1871 | |
1872 | REGISTER_OP("__MklDummyPadWithConv2D" ) |
1873 | .Input("input: T" ) |
1874 | .Input("filter: T" ) |
1875 | .Input("paddings: Tpaddings" ) |
1876 | .Output("output: T" ) |
1877 | .Attr("T: {bfloat16, float}" ) |
1878 | .Attr("strides: list(int)" ) |
1879 | .Attr("use_cudnn_on_gpu: bool = true" ) |
1880 | .Attr(GetPaddingAttrString()) |
1881 | .Attr(GetConvnetDataFormatAttrString()) |
1882 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1883 | .Attr("Tpaddings: {int32, int64} = DT_INT32" ) |
1884 | .SetShapeFn(shape_inference::Conv2DShape) |
1885 | .Doc(R"doc( |
1886 | Dummy node that enables fusing Pad and Conv2D operator for MKL. This node |
1887 | does not perform anything. It is just created as an intermediate output of |
1888 | merging Pad and Conv2D. |
1889 | |
1890 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
1891 | expected to invoke these operators. |
1892 | )doc" ); |
1893 | |
1894 | REGISTER_OP("_MklPadWithConv2D" ) |
1895 | .Input("input: T" ) |
1896 | .Input("filter: T" ) |
1897 | .Input("paddings: Tpaddings" ) |
1898 | .Input("mkl_input: uint8" ) |
1899 | .Input("mkl_filter: uint8" ) |
1900 | .Input("mkl_paddings: uint8" ) |
1901 | .Output("output: T" ) |
1902 | .Output("filter_output: T" ) |
1903 | .Output("mkl_output: uint8" ) |
1904 | .Output("mkl_filter_output: uint8" ) |
1905 | .Attr("T: {bfloat16, float}" ) |
1906 | .Attr("strides: list(int)" ) |
1907 | .Attr("use_cudnn_on_gpu: bool = true" ) |
1908 | .Attr(GetPaddingAttrString()) |
1909 | .Attr(GetConvnetDataFormatAttrString()) |
1910 | .Attr("is_filter_const: bool = false" ) |
1911 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1912 | .Attr("Tpaddings: {int32, int64} = DT_INT32" ) |
1913 | .SetShapeFn(shape_inference::Conv2DShape) |
1914 | .Doc(R"doc( |
1915 | MKL version of Pad and Conv2D operator. Uses MKL DNN APIs to perform |
1916 | Pad and 2D convolution to the output of convolution. |
1917 | |
1918 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
1919 | expected to invoke these operators. |
1920 | )doc" ); |
1921 | |
1922 | REGISTER_OP("_MklConv2DBackpropFilter" ) |
1923 | .Input("input: T" ) |
1924 | .Input("filter_sizes: int32" ) |
1925 | .Input("out_backprop: T" ) |
1926 | .Input("mkl_input: uint8" ) |
1927 | .Input("mkl_filter_size: uint8" ) |
1928 | .Input("mkl_out_backprop: uint8" ) |
1929 | .Output("output: T" ) |
1930 | .Output("mkl_output: uint8" ) |
1931 | .Attr("T: {bfloat16, float}" ) |
1932 | .Attr("strides: list(int)" ) |
1933 | .Attr("use_cudnn_on_gpu: bool = true" ) |
1934 | .Attr(GetPaddingAttrString()) |
1935 | .Attr(GetConvnetDataFormatAttrString()) |
1936 | .Attr(GetExplicitPaddingsAttrString()) |
1937 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1938 | .SetShapeFn([](InferenceContext* c) { |
1939 | ShapeHandle s; |
1940 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); |
1941 | TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); |
1942 | c->set_output(0, s); |
1943 | return Status::OK(); |
1944 | }) |
1945 | .Doc(R"doc( |
1946 | MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the |
1947 | gradients of convolution with respect to the filter. |
1948 | |
1949 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
1950 | expected to invoke these operators. |
1951 | )doc" ); |
1952 | |
1953 | REGISTER_OP("_MklNativeConv2DBackpropFilter" ) |
1954 | .Input("input: T" ) |
1955 | .Input("filter_sizes: int32" ) |
1956 | .Input("out_backprop: T" ) |
1957 | .Output("output: T" ) |
1958 | .Attr("T: {bfloat16, float}" ) |
1959 | .Attr("strides: list(int)" ) |
1960 | .Attr("use_cudnn_on_gpu: bool = true" ) |
1961 | .Attr(GetPaddingAttrStringWithExplicit()) |
1962 | .Attr(GetExplicitPaddingsAttrString()) |
1963 | .Attr(GetConvnetDataFormatAttrString()) |
1964 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1965 | .SetShapeFn([](InferenceContext* c) { |
1966 | ShapeHandle s; |
1967 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); |
1968 | TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); |
1969 | c->set_output(0, s); |
1970 | return Status::OK(); |
1971 | }) |
1972 | .Doc(R"doc( |
1973 | MKL version of Conv2DBackpropFilter for Eager mode. Uses MKL DNN APIs |
1974 | to compute the gradients of convolution with respect to the filter. |
1975 | |
1976 | NOTE Do not invoke this operator directly in Python. Eager Op rewrite pass is |
1977 | expected to invoke these operators. |
1978 | )doc" ); |
1979 | |
1980 | REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias" ) |
1981 | .Input("input: T" ) |
1982 | .Input("filter_sizes: int32" ) |
1983 | .Input("out_backprop: T" ) |
1984 | .Output("output: T" ) |
1985 | .Output("bias_grad: T" ) |
1986 | .Attr("T: {bfloat16, float}" ) |
1987 | .Attr("strides: list(int)" ) |
1988 | .Attr("use_cudnn_on_gpu: bool = true" ) |
1989 | .Attr(GetPaddingAttrString()) |
1990 | .Attr(GetConvnetDataFormatAttrString()) |
1991 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
1992 | .SetShapeFn([](InferenceContext* c) { |
1993 | ShapeHandle input_shape; |
1994 | // Fetch the data_format attribute, which may not exist. |
1995 | string data_format; |
1996 | Status s = c->GetAttr("data_format" , &data_format); |
1997 | |
1998 | if (s.ok() && data_format == "NCHW" ) { |
1999 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
2000 | c->set_output(1, c->Vector(c->Dim(input_shape, -3))); |
2001 | } else { |
2002 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); |
2003 | c->set_output(1, c->Vector(c->Dim(input_shape, -1))); |
2004 | } |
2005 | ShapeHandle sh; |
2006 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh)); |
2007 | TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh)); |
2008 | c->set_output(0, sh); |
2009 | return Status::OK(); |
2010 | }) |
2011 | .Doc(R"doc( |
2012 | Dummy node that enables fusing Conv2DBackpropFilter and BiasAddGrad operator |
2013 | for MKL. This node does not perform anything. It is just created as an |
2014 | intermediate output of merging Conv2DBackpropFilter and BiasAddGrad. |
2015 | |
2016 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2017 | expected to invoke these operators. |
2018 | )doc" ); |
2019 | |
2020 | REGISTER_OP("_MklConv2DBackpropFilterWithBias" ) |
2021 | .Input("input: T" ) |
2022 | .Input("filter_sizes: int32" ) |
2023 | .Input("out_backprop: T" ) |
2024 | .Input("mkl_input: uint8" ) |
2025 | .Input("mkl_filter_size: uint8" ) |
2026 | .Input("mkl_out_backprop: uint8" ) |
2027 | .Output("output: T" ) |
2028 | .Output("bias_grad: T" ) |
2029 | .Output("mkl_output: uint8" ) |
2030 | .Output("mkl_bias_grad: uint8" ) |
2031 | .Attr("T: {bfloat16, float}" ) |
2032 | .Attr("strides: list(int)" ) |
2033 | .Attr("use_cudnn_on_gpu: bool = true" ) |
2034 | .Attr(GetPaddingAttrString()) |
2035 | .Attr(GetConvnetDataFormatAttrString()) |
2036 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2037 | .SetShapeFn(shape_inference::Conv2DBackpropFilterWithBiasShape) |
2038 | .Doc(R"doc( |
2039 | MKL version of Conv2DBackpropFilterWithBias. Uses MKL DNN APIs to compute the |
2040 | gradients of convolution with respect to the filter. |
2041 | |
2042 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2043 | expected to invoke these operators. |
2044 | )doc" ); |
2045 | |
2046 | #ifdef INTEL_MKL_ML_ONLY |
2047 | REGISTER_OP("_MklConv2DWithBiasBackpropBias" ) |
2048 | .Input("out_backprop: T" ) |
2049 | .Input("mkl_out_backprop: uint8" ) |
2050 | .Output("output: T" ) |
2051 | .Output("mkl_output: uint8" ) |
2052 | .Attr("T: {half, float, double}" ) |
2053 | .Attr("strides: list(int)" ) |
2054 | .Attr(GetConvnetDataFormatAttrString()) |
2055 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2056 | .Doc(R"doc( |
2057 | MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the |
2058 | gradients of convolution with respect to the bias. |
2059 | |
2060 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2061 | expected to invoke these operators. |
2062 | )doc" ); |
2063 | #endif |
2064 | |
2065 | REGISTER_OP("_MklConv2DBackpropInput" ) |
2066 | .Input("input_sizes: int32" ) |
2067 | .Input("filter: T" ) |
2068 | .Input("out_backprop: T" ) |
2069 | .Input("mkl_input_sizes: uint8" ) |
2070 | .Input("mkl_filter: uint8" ) |
2071 | .Input("mkl_out_backprop: uint8" ) |
2072 | .Output("output: T" ) |
2073 | .Output("mkl_output: uint8" ) |
2074 | .Attr("T: {bfloat16, float}" ) |
2075 | .Attr("strides: list(int)" ) |
2076 | .Attr("use_cudnn_on_gpu: bool = true" ) |
2077 | .Attr(GetPaddingAttrString()) |
2078 | .Attr(GetConvnetDataFormatAttrString()) |
2079 | .Attr(GetExplicitPaddingsAttrString()) |
2080 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2081 | .SetShapeFn([](InferenceContext* c) { |
2082 | ShapeHandle s; |
2083 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); |
2084 | TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); |
2085 | c->set_output(0, s); |
2086 | return Status::OK(); |
2087 | }) |
2088 | .Doc(R"doc( |
2089 | MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the |
2090 | gradients of convolution with respect to the input. |
2091 | |
2092 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2093 | expected to invoke these operators. |
2094 | )doc" ); |
2095 | |
2096 | REGISTER_OP("_MklNativeConv2DBackpropInput" ) |
2097 | .Input("input_sizes: int32" ) |
2098 | .Input("filter: T" ) |
2099 | .Input("out_backprop: T" ) |
2100 | .Output("output: T" ) |
2101 | .Attr("T: {bfloat16, float}" ) |
2102 | .Attr("strides: list(int)" ) |
2103 | .Attr("use_cudnn_on_gpu: bool = true" ) |
2104 | .Attr(GetPaddingAttrStringWithExplicit()) |
2105 | .Attr(GetExplicitPaddingsAttrString()) |
2106 | .Attr(GetConvnetDataFormatAttrString()) |
2107 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2108 | .SetShapeFn([](InferenceContext* c) { |
2109 | ShapeHandle s; |
2110 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); |
2111 | TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s)); |
2112 | c->set_output(0, s); |
2113 | return Status::OK(); |
2114 | }) |
2115 | .Doc(R"doc( |
2116 | MKL version of Convolution2D backward input for Eager mode. Uses MKL DNN APIs |
2117 | to compute the gradients of convolution with respect to the input. |
2118 | |
2119 | NOTE Do not invoke this operator directly in Python. Eager op rewrite is |
2120 | expected to invoke these operators. |
2121 | )doc" ); |
2122 | |
2123 | REGISTER_OP("_MklConv3D" ) |
2124 | .Input("input: T" ) |
2125 | .Input("filter: T" ) |
2126 | .Input("mkl_input: uint8" ) |
2127 | .Input("mkl_filter: uint8" ) |
2128 | .Output("output: T" ) |
2129 | .Output("filter_output: T" ) |
2130 | .Output("mkl_output: uint8" ) |
2131 | .Output("mkl_filter_output: uint8" ) |
2132 | .Attr("T: {bfloat16, float}" ) |
2133 | .Attr("strides: list(int) >= 5" ) |
2134 | .Attr("is_filter_const: bool = false" ) |
2135 | .Attr(GetPaddingAttrString()) |
2136 | .Attr(GetConvnet3dDataFormatAttrString()) |
2137 | .Attr("dilations: list(int) = [1, 1, 1, 1, 1]" ) |
2138 | .SetShapeFn(shape_inference::Conv3DShape) |
2139 | .Doc(R"doc( |
2140 | MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution. |
2141 | |
2142 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2143 | expected to invoke these operators. |
2144 | )doc" ); |
2145 | |
2146 | REGISTER_OP("_MklConv3DBackpropInputV2" ) |
2147 | .Input("input_sizes: Tshape" ) |
2148 | .Input("filter: T" ) |
2149 | .Input("out_backprop: T" ) |
2150 | .Input("mkl_input_sizes: uint8" ) |
2151 | .Input("mkl_filter: uint8" ) |
2152 | .Input("mkl_out_backprop: uint8" ) |
2153 | .Output("output: T" ) |
2154 | .Output("mkl_output: uint8" ) |
2155 | .Attr("T: {bfloat16, float}" ) |
2156 | .Attr("strides: list(int) >= 5" ) |
2157 | .Attr("dilations: list(int) = [1, 1, 1, 1, 1]" ) |
2158 | .Attr("Tshape: {int32, int64} = DT_INT32" ) |
2159 | .Attr(GetPaddingAttrString()) |
2160 | .Attr(GetConvnet3dDataFormatAttrString()) |
2161 | .SetShapeFn([](InferenceContext* c) { |
2162 | ShapeHandle s; |
2163 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s)); |
2164 | TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); |
2165 | c->set_output(0, s); |
2166 | return Status::OK(); |
2167 | }) |
2168 | .Doc(R"doc( |
2169 | MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the |
2170 | gradients of convolution with respect to the input. |
2171 | |
2172 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2173 | expected to invoke these operators. |
2174 | )doc" ); |
2175 | |
2176 | REGISTER_OP("_MklConv3DBackpropFilterV2" ) |
2177 | .Input("input: T" ) |
2178 | .Input("filter_sizes: int32" ) |
2179 | .Input("out_backprop: T" ) |
2180 | .Input("mkl_input: uint8" ) |
2181 | .Input("mkl_filter_size: uint8" ) |
2182 | .Input("mkl_out_backprop: uint8" ) |
2183 | .Output("output: T" ) |
2184 | .Output("mkl_output: uint8" ) |
2185 | .Attr("T: {bfloat16, float}" ) |
2186 | .Attr("strides: list(int)" ) |
2187 | .Attr(GetPaddingAttrString()) |
2188 | .Attr(GetConvnet3dDataFormatAttrString()) |
2189 | .Attr("dilations: list(int) = [1, 1, 1, 1, 1]" ) |
2190 | .SetShapeFn([](InferenceContext* c) { |
2191 | ShapeHandle s; |
2192 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s)); |
2193 | TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s)); |
2194 | c->set_output(0, s); |
2195 | return Status::OK(); |
2196 | }) |
2197 | .Doc(R"doc( |
2198 | MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the |
2199 | gradients of convolution with respect to the filter. |
2200 | |
2201 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2202 | expected to invoke these operators. |
2203 | )doc" ); |
2204 | |
2205 | REGISTER_OP("_MklRelu" ) |
2206 | .Input("features: T" ) |
2207 | .Input("mkl_features: uint8" ) |
2208 | .Output("activations: T" ) |
2209 | .Output("mkl_activations: uint8" ) |
2210 | .Attr("T: {float, bfloat16} = DT_FLOAT" ) |
2211 | .SetShapeFn(shape_inference::UnchangedShape) |
2212 | .Doc(R"doc( |
2213 | MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator. |
2214 | |
2215 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2216 | expected to invoke these operators. |
2217 | )doc" ); |
2218 | |
2219 | REGISTER_OP("_MklReluGrad" ) |
2220 | .Input("gradients: T" ) |
2221 | .Input("features: T" ) |
2222 | .Input("mkl_gradients: uint8" ) |
2223 | .Input("mkl_features: uint8" ) |
2224 | .Output("backprops: T" ) |
2225 | .Output("mkl_backprops: uint8" ) |
2226 | .Attr("T: {float, bfloat16} = DT_FLOAT" ) |
2227 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn) |
2228 | .Doc(R"doc( |
2229 | MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified |
2230 | linear gradients for Relu operation. |
2231 | |
2232 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2233 | expected to invoke these operators. |
2234 | )doc" ); |
2235 | |
2236 | REGISTER_OP("_MklRelu6" ) |
2237 | .Input("features: T" ) |
2238 | .Input("mkl_features: uint8" ) |
2239 | .Output("activations: T" ) |
2240 | .Output("mkl_activations: uint8" ) |
2241 | .Attr("T: {float, bfloat16} = DT_FLOAT" ) |
2242 | .SetShapeFn(shape_inference::UnchangedShape) |
2243 | .Doc(R"doc( |
2244 | MKL version of Relu6 operator. Uses MKL DNN APIs to implement Relu6 operator. |
2245 | |
2246 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2247 | expected to invoke these operators. |
2248 | )doc" ); |
2249 | |
2250 | REGISTER_OP("_MklRelu6Grad" ) |
2251 | .Input("gradients: T" ) |
2252 | .Input("features: T" ) |
2253 | .Input("mkl_gradients: uint8" ) |
2254 | .Input("mkl_features: uint8" ) |
2255 | .Output("backprops: T" ) |
2256 | .Output("mkl_backprops: uint8" ) |
2257 | .Attr("T: {float, bfloat16} = DT_FLOAT" ) |
2258 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn) |
2259 | .Doc(R"doc( |
2260 | MKL version of Relu6Grad operator. Uses MKL DNN APIs to compute rectified |
2261 | linear gradients for Relu6 operation. |
2262 | |
2263 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2264 | expected to invoke these operators. |
2265 | )doc" ); |
2266 | |
2267 | REGISTER_OP("_MklLeakyRelu" ) |
2268 | .Input("features: T" ) |
2269 | .Input("mkl_features: uint8" ) |
2270 | .Output("activations: T" ) |
2271 | .Output("mkl_activations: uint8" ) |
2272 | .Attr("T: {float, bfloat16} = DT_FLOAT" ) |
2273 | .Attr("alpha: float = 0.2" ) |
2274 | .SetShapeFn(shape_inference::UnchangedShape) |
2275 | .Doc(R"doc( |
2276 | MKL version of LeakyRelu operator. Uses MKL DNN APIs to implement |
2277 | LeakyRelu operator. |
2278 | |
2279 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2280 | expected to invoke these operators. |
2281 | )doc" ); |
2282 | |
2283 | REGISTER_OP("_MklLeakyReluGrad" ) |
2284 | .Input("gradients: T" ) |
2285 | .Input("features: T" ) |
2286 | .Input("mkl_gradients: uint8" ) |
2287 | .Input("mkl_features: uint8" ) |
2288 | .Output("backprops: T" ) |
2289 | .Output("mkl_backprops: uint8" ) |
2290 | .Attr("T: {float, bfloat16} = DT_FLOAT" ) |
2291 | .Attr("alpha: float = 0.2" ) |
2292 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn) |
2293 | .Doc(R"doc( |
2294 | MKL version of LeakyReluGrad operator. Uses MKL DNN APIs to compute rectified |
2295 | linear gradients for LeakyReluGrad operation. |
2296 | |
2297 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2298 | expected to invoke these operators. |
2299 | )doc" ); |
2300 | |
2301 | REGISTER_OP("_MklElu" ) |
2302 | .Input("features: T" ) |
2303 | .Input("mkl_features: uint8" ) |
2304 | .Output("activations: T" ) |
2305 | .Output("mkl_activations: uint8" ) |
2306 | .Attr("T: {float, bfloat16} = DT_FLOAT" ) |
2307 | .SetShapeFn(shape_inference::UnchangedShape) |
2308 | .Doc(R"doc( |
2309 | MKL version of Elu operator. Uses MKL DNN APIs to implement Elu operator. |
2310 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2311 | expected to invoke these operators. |
2312 | )doc" ); |
2313 | |
2314 | REGISTER_OP("_MklEluGrad" ) |
2315 | .Input("gradients: T" ) |
2316 | .Input("features: T" ) |
2317 | .Input("mkl_gradients: uint8" ) |
2318 | .Input("mkl_features: uint8" ) |
2319 | .Output("backprops: T" ) |
2320 | .Output("mkl_backprops: uint8" ) |
2321 | .Attr("T: {float, bfloat16} = DT_FLOAT" ) |
2322 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn) |
2323 | .Doc(R"doc( |
2324 | MKL version of EluGrad operator. Uses MKL DNN APIs to compute Elu |
2325 | gradients for Elu operation. |
2326 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2327 | expected to invoke these operators. |
2328 | )doc" ); |
2329 | |
2330 | REGISTER_OP("_MklSoftmax" ) |
2331 | .Input("logits: T" ) |
2332 | .Input("mkl_logits: uint8" ) |
2333 | .Output("softmax: T" ) |
2334 | .Output("mkl_softmax: uint8" ) |
2335 | .Attr("T: {bfloat16, half, float, double}" ) |
2336 | .SetShapeFn([](InferenceContext* c) { |
2337 | return shape_inference::UnchangedShapeWithRankAtLeast(c, 1); |
2338 | }) |
2339 | .Doc(R"doc( |
2340 | MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified |
2341 | linear gradients for Relu operation. |
2342 | )doc" ); |
2343 | |
2344 | REGISTER_OP("_MklTanh" ) |
2345 | .Input("features: T" ) |
2346 | .Input("mkl_features: uint8" ) |
2347 | .Output("activations: T" ) |
2348 | .Output("mkl_activations: uint8" ) |
2349 | .Attr("T: realnumbertype" ) |
2350 | .SetShapeFn(shape_inference::UnchangedShape) |
2351 | .Doc(R"doc( |
2352 | MKL version of Tanh operator. Uses MKL DNN APIs to implement Tanh operator. |
2353 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2354 | expected to invoke these operators. |
2355 | )doc" ); |
2356 | |
2357 | REGISTER_OP("_MklTanhGrad" ) |
2358 | .Input("gradients: T" ) |
2359 | .Input("features: T" ) |
2360 | .Input("mkl_gradients: uint8" ) |
2361 | .Input("mkl_features: uint8" ) |
2362 | .Output("backprops: T" ) |
2363 | .Output("mkl_backprops: uint8" ) |
2364 | .Attr("T: realnumbertype" ) |
2365 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn) |
2366 | .Doc(R"doc( |
2367 | MKL version of TanhGrad operator. Uses MKL DNN APIs to compute tanh |
2368 | gradients for Tanh operation. |
2369 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2370 | expected to invoke these operators. |
2371 | )doc" ); |
2372 | |
2373 | REGISTER_OP("_MklMaxPool" ) |
2374 | .Attr("T: {float, half, bfloat16} = DT_FLOAT" ) |
2375 | .Attr("ksize: list(int) >= 4" ) |
2376 | .Attr("strides: list(int) >= 4" ) |
2377 | .Attr(GetPaddingAttrString()) |
2378 | .Attr(GetConvnetDataFormatAttrString()) |
2379 | .Attr(GetExplicitPaddingsAttrString()) |
2380 | .Attr("workspace_enabled: bool = false" ) |
2381 | .Input("input: T" ) |
2382 | .Input("mkl_input: uint8" ) |
2383 | .Output("output: T" ) |
2384 | #ifdef INTEL_MKL_ML_ONLY |
2385 | .Output("workspace: T" ) |
2386 | #else |
2387 | .Output("workspace: uint8" ) |
2388 | #endif |
2389 | .Output("mkl_output: uint8" ) |
2390 | .Output("mkl_workspace: uint8" ) |
2391 | .SetShapeFn(shape_inference::MaxPoolShape) |
2392 | .Doc(R"doc( |
2393 | MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling |
2394 | on the input. |
2395 | |
2396 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2397 | expected to invoke these operators. |
2398 | )doc" ); |
2399 | |
2400 | REGISTER_OP("_MklMaxPoolGrad" ) |
2401 | .Attr("T: {float, half, bfloat16} = DT_FLOAT" ) |
2402 | .Attr("ksize: list(int) >= 4" ) |
2403 | .Attr("strides: list(int) >= 4" ) |
2404 | .Attr("workspace_enabled: bool = false" ) |
2405 | .Attr(GetPaddingAttrString()) |
2406 | .Attr(GetConvnetDataFormatAttrString()) |
2407 | .Attr(GetExplicitPaddingsAttrString()) |
2408 | .Input("orig_input: T" ) |
2409 | .Input("orig_output: T" ) |
2410 | .Input("grad: T" ) |
2411 | #ifdef INTEL_MKL_ML_ONLY |
2412 | .Input("workspace: T" ) |
2413 | #else |
2414 | .Input("workspace: uint8" ) |
2415 | #endif |
2416 | .Input("mkl_orig_input: uint8" ) |
2417 | .Input("mkl_orig_output: uint8" ) |
2418 | .Input("mkl_grad: uint8" ) |
2419 | .Input("mkl_workspace: uint8" ) |
2420 | .Output("output: T" ) |
2421 | .Output("mkl_output: uint8" ) |
2422 | .SetShapeFn(shape_inference::MaxPoolGradShape) |
2423 | .Doc(R"doc( |
2424 | oneDNN version of MaxPoolGrad. Uses oneDNN APIs to compute gradients of |
2425 | MaxPool operator. |
2426 | |
2427 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
2428 | expected to invoke these operators. |
2429 | )doc" ); |
2430 | |
2431 | REGISTER_OP("_MklAvgPool" ) |
2432 | .Input("value: T" ) |
2433 | .Input("mkl_input: uint8" ) |
2434 | .Output("output: T" ) |
2435 | .Output("mkl_output: uint8" ) |
2436 | .Attr("ksize: list(int) >= 4" ) |
2437 | .Attr("strides: list(int) >= 4" ) |
2438 | .Attr(GetPaddingAttrString()) |
2439 | .Attr(GetConvnetDataFormatAttrString()) |
2440 | .Attr("T: {float, half, double, bfloat16}" ) |
2441 | .SetShapeFn(shape_inference::AvgPoolShape) |
2442 | .Doc(R"doc( |
2443 | MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling |
2444 | on the input. |
2445 | |
2446 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2447 | expected to invoke these operators. |
2448 | )doc" ); |
2449 | |
2450 | REGISTER_OP("_MklAvgPoolGrad" ) |
2451 | .Input("orig_input_shape: int32" ) |
2452 | .Input("grad: T" ) |
2453 | .Input("mkl_orig_input: uint8" ) |
2454 | .Input("mkl_grad: uint8" ) |
2455 | .Output("output: T" ) |
2456 | .Output("mkl_output: uint8" ) |
2457 | .Attr("ksize: list(int) >= 4" ) |
2458 | .Attr("strides: list(int) >= 4" ) |
2459 | .Attr(GetPaddingAttrString()) |
2460 | .Attr(GetConvnetDataFormatAttrString()) |
2461 | .Attr("T: {float, half, double, bfloat16}" ) |
2462 | .SetShapeFn(shape_inference::AvgPoolGradShape) |
2463 | .Doc(R"doc( |
2464 | oneDNN version of AvgPoolGrad operator. Uses oneDNN APIs to compute gradients |
2465 | of AvgPool function. |
2466 | |
2467 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
2468 | expected to invoke these operators. |
2469 | )doc" ); |
2470 | |
2471 | REGISTER_OP("_MklAvgPool3D" ) |
2472 | .Input("value: T" ) |
2473 | .Input("mkl_input: uint8" ) |
2474 | .Output("output: T" ) |
2475 | .Output("mkl_output: uint8" ) |
2476 | .Attr("ksize: list(int) >= 5" ) |
2477 | .Attr("strides: list(int) >= 5" ) |
2478 | .Attr(GetPaddingAttrString()) |
2479 | .Attr(GetConvnet3dDataFormatAttrString()) |
2480 | .Attr("T: {float, half, double, bfloat16}" ) |
2481 | .SetShapeFn(shape_inference::Pool3DShape) |
2482 | .Doc(R"doc( |
2483 | MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling |
2484 | on the input. |
2485 | |
2486 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2487 | expected to invoke these operators. |
2488 | )doc" ); |
2489 | |
2490 | REGISTER_OP("_MklAvgPool3DGrad" ) |
2491 | .Input("orig_input_shape: int32" ) |
2492 | .Input("grad: T" ) |
2493 | .Input("mkl_orig_input: uint8" ) |
2494 | .Input("mkl_grad: uint8" ) |
2495 | .Output("output: T" ) |
2496 | .Output("mkl_output: uint8" ) |
2497 | .Attr("ksize: list(int) >= 5" ) |
2498 | .Attr("strides: list(int) >= 5" ) |
2499 | .Attr(GetPaddingAttrString()) |
2500 | .Attr(GetConvnet3dDataFormatAttrString()) |
2501 | .Attr("T: {float, half, double, bfloat16}" ) |
2502 | .SetShapeFn(shape_inference::AvgPool3DGradShape) |
2503 | .Doc(R"doc( |
2504 | oneDNN version of AvgPool3DGrad operator. Uses oneDNN APIs to compute gradients |
2505 | of AvgPool function. |
2506 | |
2507 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
2508 | expected to invoke these operators. |
2509 | )doc" ); |
2510 | |
2511 | REGISTER_OP("_MklMaxPool3D" ) |
2512 | .Input("input: T" ) |
2513 | .Input("mkl_input: uint8" ) |
2514 | .Output("output: T" ) |
2515 | .Output("workspace: uint8" ) |
2516 | .Output("mkl_output: uint8" ) |
2517 | .Output("mkl_workspace: uint8" ) |
2518 | .Attr("ksize: list(int) >= 5" ) |
2519 | .Attr("strides: list(int) >= 5" ) |
2520 | .Attr(GetPaddingAttrString()) |
2521 | .Attr(GetConvnet3dDataFormatAttrString()) |
2522 | .Attr("T: {half, bfloat16, float}" ) |
2523 | .Attr("workspace_enabled: bool = false" ) |
2524 | .SetShapeFn(shape_inference::Pool3DShape) |
2525 | .Doc(R"doc( |
2526 | MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform average pooling |
2527 | on the input. |
2528 | |
2529 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2530 | expected to invoke these operators. |
2531 | )doc" ); |
2532 | |
2533 | REGISTER_OP("_MklMaxPool3DGrad" ) |
2534 | .Input("orig_input: TInput" ) |
2535 | .Input("orig_output: TInput" ) |
2536 | .Input("grad: T" ) |
2537 | .Input("workspace: uint8" ) |
2538 | .Input("mkl_orig_input: uint8" ) |
2539 | .Input("mkl_orig_output: uint8" ) |
2540 | .Input("mkl_grad: uint8" ) |
2541 | .Input("mkl_workspace: uint8" ) |
2542 | .Output("output: T" ) |
2543 | .Output("mkl_output: uint8" ) |
2544 | .Attr("ksize: list(int) >= 5" ) |
2545 | .Attr("strides: list(int) >= 5" ) |
2546 | .Attr(GetPaddingAttrString()) |
2547 | .Attr(GetConvnet3dDataFormatAttrString()) |
2548 | .Attr("T: {half, bfloat16, float} = DT_FLOAT" ) |
2549 | .Attr("TInput: {half, bfloat16, float} = DT_FLOAT" ) |
2550 | .Attr("workspace_enabled: bool = false" ) |
2551 | .SetShapeFn(shape_inference::MaxPool3DGradShape) |
2552 | .Doc(R"doc( |
2553 | oneDNN version of MaxPool3DGrad operator. Uses oneDNN APIs to compute gradients |
2554 | of MaxPool3D function. |
2555 | |
2556 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
2557 | expected to invoke these operators. |
2558 | )doc" ); |
2559 | |
2560 | REGISTER_OP("_MklLRN" ) |
2561 | .Input("input: T" ) |
2562 | .Input("mkl_input: uint8" ) |
2563 | .Output("output: T" ) |
2564 | .Output("workspace: uint8" ) |
2565 | .Output("mkl_output: uint8" ) |
2566 | .Output("mkl_workspace: uint8" ) |
2567 | .Attr("depth_radius: int = 5" ) |
2568 | .Attr("bias: float = 1.0" ) |
2569 | .Attr("alpha: float = 1.0" ) |
2570 | .Attr("beta: float = 0.5" ) |
2571 | .Attr("workspace_enabled: bool = false" ) |
2572 | .Attr("T: {float, half} = DT_FLOAT" ) |
2573 | .SetShapeFn([](InferenceContext* c) { |
2574 | return UnchangedShapeWithRank(c, 4); |
2575 | }) |
2576 | .Doc(R"doc( |
2577 | MKL version of LRN operator. Uses MKL DNN APIs to perform local response |
2578 | normalization. |
2579 | |
2580 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2581 | expected to invoke these operators. |
2582 | )doc" ); |
2583 | |
2584 | REGISTER_OP("_MklLRNGrad" ) |
2585 | .Input("input_grads: T" ) |
2586 | .Input("input_image: T" ) |
2587 | .Input("output_image: T" ) |
2588 | .Input("workspace: uint8" ) |
2589 | .Input("mkl_input_grads: uint8" ) |
2590 | .Input("mkl_input_image: uint8" ) |
2591 | .Input("mkl_output_image: uint8" ) |
2592 | .Input("mkl_workspace: uint8" ) |
2593 | .Output("output: T" ) |
2594 | .Output("mkl_output: uint8" ) |
2595 | .Attr("depth_radius: int = 5" ) |
2596 | .Attr("bias: float = 1.0" ) |
2597 | .Attr("alpha: float = 1.0" ) |
2598 | .Attr("beta: float = 0.5" ) |
2599 | .Attr("workspace_enabled: bool = false" ) |
2600 | .Attr("T: {float, half} = DT_FLOAT" ) |
2601 | .SetShapeFn([](InferenceContext* c) { |
2602 | ShapeHandle s; |
2603 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads |
2604 | TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image |
2605 | TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image |
2606 | c->set_output(0, s); |
2607 | return Status::OK(); |
2608 | }) |
2609 | .Doc(R"doc( |
2610 | MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for |
2611 | local response normalization. |
2612 | |
2613 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2614 | expected to invoke these operators. |
2615 | )doc" ); |
2616 | |
2617 | REGISTER_OP("_MklFusedBatchNorm" ) |
2618 | .Input("x: T" ) |
2619 | .Input("scale: T" ) |
2620 | .Input("offset: T" ) |
2621 | .Input("mean: T" ) |
2622 | .Input("variance: T" ) |
2623 | .Input("mkl_x: uint8" ) |
2624 | .Input("mkl_scale: uint8" ) |
2625 | .Input("mkl_offset: uint8" ) |
2626 | .Input("mkl_mean: uint8" ) |
2627 | .Input("mkl_variance: uint8" ) |
2628 | .Output("y: T" ) |
2629 | .Output("batch_mean: T" ) |
2630 | .Output("batch_variance: T" ) |
2631 | .Output("reserve_space_1: T" ) |
2632 | .Output("reserve_space_2: T" ) |
2633 | .Output("mkl_y: uint8" ) |
2634 | .Output("mkl_batch_mean: uint8" ) |
2635 | .Output("mkl_batch_variance: uint8" ) |
2636 | .Output("mkl_reserve_space_1: uint8" ) |
2637 | .Output("mkl_reserve_space_2: uint8" ) |
2638 | .Attr("T: numbertype" ) |
2639 | .Attr("epsilon: float = 0.0001" ) |
2640 | .Attr("data_format: string = 'NHWC'" ) |
2641 | .Attr("exponential_avg_factor: float = 1.0" ) |
2642 | .Attr("is_training: bool = true" ) |
2643 | .SetShapeFn(shape_inference::FusedBatchNormShape) |
2644 | .Doc(R"doc( |
2645 | oneDNN version of FusedBatchNorm operator. Uses oneDNN APIs to perform fused |
2646 | batch normalization. |
2647 | |
2648 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
2649 | expected to invoke these operators. |
2650 | )doc" ); |
2651 | |
2652 | REGISTER_OP("_MklFusedBatchNormGrad" ) |
2653 | .Input("y_backprop: T" ) |
2654 | .Input("x: T" ) |
2655 | .Input("scale: T" ) |
2656 | .Input("reserve_space_1: T" ) |
2657 | .Input("reserve_space_2: T" ) |
2658 | .Input("mkl_y_backprop: uint8" ) |
2659 | .Input("mkl_x: uint8" ) |
2660 | .Input("mkl_scale: uint8" ) |
2661 | .Input("mkl_reserve_space_1: uint8" ) |
2662 | .Input("mkl_reserve_space_2: uint8" ) |
2663 | .Output("x_backprop: T" ) |
2664 | .Output("scale_backprop: T" ) |
2665 | .Output("offset_backprop: T" ) |
2666 | .Output("reserve_space_3: T" ) |
2667 | .Output("reserve_space_4: T" ) |
2668 | .Output("mkl_x_backprop: uint8" ) |
2669 | .Output("mkl_scale_backprop: uint8" ) |
2670 | .Output("mkl_offset_backprop: uint8" ) |
2671 | .Output("mkl_reserve_space_3: uint8" ) |
2672 | .Output("mkl_reserve_space_4: uint8" ) |
2673 | .Attr("T: numbertype" ) |
2674 | .Attr("epsilon: float = 0.0001" ) |
2675 | .Attr("data_format: string = 'NHWC'" ) |
2676 | .Attr("is_training: bool = true" ) |
2677 | .SetShapeFn(shape_inference::FusedBatchNormGradShape) |
2678 | .Doc(R"doc( |
2679 | oneDNN version of FusedBatchNormGrad operator. Uses oneDNN APIs to compute |
2680 | gradients for fused batch normalization. |
2681 | |
2682 | *NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is |
2683 | expected to invoke these operators. |
2684 | )doc" ); |
2685 | |
2686 | REGISTER_OP("_MklFusedBatchNormV2" ) |
2687 | .Input("x: T" ) |
2688 | .Input("scale: U" ) |
2689 | .Input("offset: U" ) |
2690 | .Input("mean: U" ) |
2691 | .Input("variance: U" ) |
2692 | .Input("mkl_x: uint8" ) |
2693 | .Input("mkl_scale: uint8" ) |
2694 | .Input("mkl_offset: uint8" ) |
2695 | .Input("mkl_mean: uint8" ) |
2696 | .Input("mkl_variance: uint8" ) |
2697 | .Output("y: T" ) |
2698 | .Output("batch_mean: U" ) |
2699 | .Output("batch_variance: U" ) |
2700 | .Output("reserve_space_1: U" ) |
2701 | .Output("reserve_space_2: U" ) |
2702 | .Output("mkl_y: uint8" ) |
2703 | .Output("mkl_batch_mean: uint8" ) |
2704 | .Output("mkl_batch_variance: uint8" ) |
2705 | .Output("mkl_reserve_space_1: uint8" ) |
2706 | .Output("mkl_reserve_space_2: uint8" ) |
2707 | .Attr("T: {bfloat16, float}" ) |
2708 | .Attr("U: {float}" ) |
2709 | .Attr("epsilon: float = 0.0001" ) |
2710 | .Attr(GetConvnetDataFormatAttrString()) |
2711 | .Attr("exponential_avg_factor: float = 1.0" ) |
2712 | .Attr("is_training: bool = true" ) |
2713 | .SetShapeFn(shape_inference::FusedBatchNormShape); |
2714 | |
2715 | REGISTER_OP("_MklFusedBatchNormGradV2" ) |
2716 | .Input("y_backprop: T" ) |
2717 | .Input("x: T" ) |
2718 | .Input("scale: float" ) |
2719 | .Input("reserve_space_1: U" ) |
2720 | .Input("reserve_space_2: U" ) |
2721 | .Input("mkl_y_backprop: uint8" ) |
2722 | .Input("mkl_x: uint8" ) |
2723 | .Input("mkl_scale: uint8" ) |
2724 | .Input("mkl_reserve_space_1: uint8" ) |
2725 | .Input("mkl_reserve_space_2: uint8" ) |
2726 | .Output("x_backprop: T" ) |
2727 | .Output("scale_backprop: U" ) |
2728 | .Output("offset_backprop: U" ) |
2729 | .Output("reserve_space_3: U" ) |
2730 | .Output("reserve_space_4: U" ) |
2731 | .Output("mkl_x_backprop: uint8" ) |
2732 | .Output("mkl_scale_backprop: uint8" ) |
2733 | .Output("mkl_offset_backprop: uint8" ) |
2734 | .Output("mkl_reserve_space_3: uint8" ) |
2735 | .Output("mkl_reserve_space_4: uint8" ) |
2736 | .Attr("T: {bfloat16, float}" ) |
2737 | .Attr("U: {float}" ) |
2738 | .Attr("epsilon: float = 0.0001" ) |
2739 | .Attr(GetConvnetDataFormatAttrString()) |
2740 | .Attr("is_training: bool = true" ) |
2741 | .SetShapeFn(shape_inference::FusedBatchNormGradShape); |
2742 | |
2743 | REGISTER_OP("_MklToTf" ) |
2744 | .Input("input: T" ) |
2745 | .Input("mkl_input: uint8" ) |
2746 | .Output("output: T" ) |
2747 | .Attr("T: {half, float, double, bfloat16, qint8, quint8, qint32}" ) |
2748 | .Attr(GetConvnetDataFormat2D3DAttrString()) |
2749 | .SetShapeFn(shape_inference::UnknownShape) |
2750 | .Doc(R"doc( |
2751 | MKL operator to convert a tensor from MKL layout to TensorFlow layout. |
2752 | |
2753 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2754 | expected to invoke these operators. |
2755 | )doc" ); |
2756 | |
2757 | REGISTER_OP("_MklInputConversion" ) |
2758 | .Input("input_0: T" ) |
2759 | .Input("input_1: T" ) |
2760 | .Input("mkl_input_0: uint8" ) |
2761 | .Input("mkl_input_1: uint8" ) |
2762 | .Output("output_0: T" ) |
2763 | .Output("output_1: T" ) |
2764 | .Output("mkl_output_0: uint8" ) |
2765 | .Output("mkl_output_1: uint8" ) |
2766 | // All datatypes supported by element-wise ops |
2767 | .Attr( |
2768 | "T: {half, float, bfloat16, double, uint8, int8, uint16, int16, int32, " |
2769 | "int64, complex64, complex128}" ) |
2770 | .Attr(GetConvnetDataFormat2D3DAttrString()) |
2771 | .SetShapeFn(shape_inference::UnknownShape) |
2772 | .Doc(R"doc( |
2773 | MKL operator to process the inputs to an elementwise MKL op. Both inputs |
2774 | need to be either in TF or in MKL format. This op is added before every |
2775 | element-wise MKL op. |
2776 | |
2777 | NOTE Do not invoke this operator directly in Python. Graph rewrite pass is |
2778 | expected to invoke these operators. |
2779 | )doc" ); |
2780 | |
2781 | #endif // INTEL_MKL |
2782 | REGISTER_OP("QuantizedConv2DAndRequantize" ) |
2783 | .Input("input: Tinput" ) |
2784 | .Input("filter: Tfilter" ) |
2785 | .Input("min_input: float" ) |
2786 | .Input("max_input: float" ) |
2787 | .Input("min_filter: float" ) |
2788 | .Input("max_filter: float" ) |
2789 | .Input("min_freezed_output: float" ) |
2790 | .Input("max_freezed_output: float" ) |
2791 | .Output("output: out_type" ) |
2792 | .Output("min_output: float" ) |
2793 | .Output("max_output: float" ) |
2794 | .Attr("Tinput: quantizedtype" ) |
2795 | .Attr("Tfilter: quantizedtype" ) |
2796 | .Attr("out_type: quantizedtype = DT_QINT8" ) |
2797 | .Attr("strides: list(int)" ) |
2798 | .Attr(GetPaddingAttrString()) |
2799 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2800 | .Attr("padding_list: list(int) = []" ) |
2801 | .SetShapeFn([](InferenceContext* c) { |
2802 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
2803 | ShapeHandle unused; |
2804 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
2805 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
2806 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
2807 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
2808 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
2809 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
2810 | c->set_output(1, c->Scalar()); |
2811 | c->set_output(2, c->Scalar()); |
2812 | return OkStatus(); |
2813 | }); |
2814 | |
2815 | // Fusion of Quantized Conv2D and BiasAdd. |
2816 | REGISTER_OP("QuantizedConv2DWithBias" ) |
2817 | .Input("input: Tinput" ) |
2818 | .Input("filter: Tfilter" ) |
2819 | .Input("bias: float" ) |
2820 | .Input("min_input: float" ) |
2821 | .Input("max_input: float" ) |
2822 | .Input("min_filter: float" ) |
2823 | .Input("max_filter: float" ) |
2824 | .Output("output: out_type" ) |
2825 | .Output("min_output: float" ) |
2826 | .Output("max_output: float" ) |
2827 | .Attr("Tinput: quantizedtype" ) |
2828 | .Attr("Tfilter: quantizedtype" ) |
2829 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
2830 | .Attr("strides: list(int)" ) |
2831 | .Attr(GetPaddingAttrString()) |
2832 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2833 | .Attr("padding_list: list(int) = []" ) |
2834 | .SetShapeFn([](InferenceContext* c) { |
2835 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
2836 | ShapeHandle unused, channel; |
2837 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
2838 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
2839 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
2840 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
2841 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); |
2842 | c->set_output(1, channel); |
2843 | c->set_output(2, channel); |
2844 | return OkStatus(); |
2845 | }); |
2846 | |
2847 | REGISTER_OP("QuantizedConv2DWithBiasAndRequantize" ) |
2848 | .Input("input: Tinput" ) |
2849 | .Input("filter: Tfilter" ) |
2850 | .Input("bias: Tbias" ) |
2851 | .Input("min_input: float" ) |
2852 | .Input("max_input: float" ) |
2853 | .Input("min_filter: float" ) |
2854 | .Input("max_filter: float" ) |
2855 | .Input("min_freezed_output: float" ) |
2856 | .Input("max_freezed_output: float" ) |
2857 | .Output("output: out_type" ) |
2858 | .Output("min_output: float" ) |
2859 | .Output("max_output: float" ) |
2860 | .Attr("Tinput: quantizedtype" ) |
2861 | .Attr("Tfilter: quantizedtype" ) |
2862 | .Attr("Tbias: {float, qint32}" ) |
2863 | .Attr("out_type: quantizedtype = DT_QINT8" ) |
2864 | .Attr("strides: list(int)" ) |
2865 | .Attr(GetPaddingAttrString()) |
2866 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2867 | .Attr("padding_list: list(int) = []" ) |
2868 | .SetShapeFn([](InferenceContext* c) { |
2869 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
2870 | ShapeHandle unused, channel; |
2871 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
2872 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
2873 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
2874 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
2875 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); |
2876 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
2877 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); |
2878 | c->set_output(1, c->Scalar()); |
2879 | c->set_output(2, c->Scalar()); |
2880 | return OkStatus(); |
2881 | }); |
2882 | |
2883 | // Fusion of Quantized Conv2D and Relu. |
2884 | REGISTER_OP("QuantizedConv2DAndRelu" ) |
2885 | .Input("input: Tinput" ) |
2886 | .Input("filter: Tfilter" ) |
2887 | .Input("min_input: float" ) |
2888 | .Input("max_input: float" ) |
2889 | .Input("min_filter: float" ) |
2890 | .Input("max_filter: float" ) |
2891 | .Output("output: out_type" ) |
2892 | .Output("min_output: float" ) |
2893 | .Output("max_output: float" ) |
2894 | .Attr("Tinput: quantizedtype" ) |
2895 | .Attr("Tfilter: quantizedtype" ) |
2896 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
2897 | .Attr("strides: list(int)" ) |
2898 | .Attr(GetPaddingAttrString()) |
2899 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2900 | .Attr("padding_list: list(int) = []" ) |
2901 | .SetShapeFn([](InferenceContext* c) { |
2902 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
2903 | ShapeHandle unused, channel; |
2904 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
2905 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
2906 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel)); |
2907 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
2908 | c->set_output(1, channel); |
2909 | c->set_output(2, channel); |
2910 | return OkStatus(); |
2911 | }); |
2912 | |
2913 | REGISTER_OP("QuantizedConv2DAndReluAndRequantize" ) |
2914 | .Input("input: Tinput" ) |
2915 | .Input("filter: Tfilter" ) |
2916 | .Input("min_input: float" ) |
2917 | .Input("max_input: float" ) |
2918 | .Input("min_filter: float" ) |
2919 | .Input("max_filter: float" ) |
2920 | .Input("min_freezed_output: float" ) |
2921 | .Input("max_freezed_output: float" ) |
2922 | .Output("output: out_type" ) |
2923 | .Output("min_output: float" ) |
2924 | .Output("max_output: float" ) |
2925 | .Attr("Tinput: quantizedtype" ) |
2926 | .Attr("Tfilter: quantizedtype" ) |
2927 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
2928 | .Attr("strides: list(int)" ) |
2929 | .Attr(GetPaddingAttrString()) |
2930 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2931 | .Attr("padding_list: list(int) = []" ) |
2932 | .SetShapeFn([](InferenceContext* c) { |
2933 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
2934 | ShapeHandle unused, channel; |
2935 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
2936 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
2937 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel)); |
2938 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
2939 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
2940 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
2941 | c->set_output(1, c->Scalar()); |
2942 | c->set_output(2, c->Scalar()); |
2943 | return OkStatus(); |
2944 | }); |
2945 | |
2946 | // Fusion of Quantized Conv2D, BiasAdd and Relu. |
2947 | REGISTER_OP("QuantizedConv2DWithBiasAndRelu" ) |
2948 | .Input("input: Tinput" ) |
2949 | .Input("filter: Tfilter" ) |
2950 | .Input("bias: float" ) |
2951 | .Input("min_input: float" ) |
2952 | .Input("max_input: float" ) |
2953 | .Input("min_filter: float" ) |
2954 | .Input("max_filter: float" ) |
2955 | .Output("output: out_type" ) |
2956 | .Output("min_output: float" ) |
2957 | .Output("max_output: float" ) |
2958 | .Attr("Tinput: quantizedtype" ) |
2959 | .Attr("Tfilter: quantizedtype" ) |
2960 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
2961 | .Attr("strides: list(int)" ) |
2962 | .Attr(GetPaddingAttrString()) |
2963 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2964 | .Attr("padding_list: list(int) = []" ) |
2965 | .SetShapeFn([](InferenceContext* c) { |
2966 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
2967 | ShapeHandle unused, channel; |
2968 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
2969 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
2970 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
2971 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
2972 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); |
2973 | c->set_output(1, channel); |
2974 | c->set_output(2, channel); |
2975 | return OkStatus(); |
2976 | }); |
2977 | |
2978 | // Fusion of Quantized Conv2D, BiasAdd, Relu, and Requantize. |
2979 | REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize" ) |
2980 | .Input("input: Tinput" ) |
2981 | .Input("filter: Tfilter" ) |
2982 | .Input("bias: Tbias" ) |
2983 | .Input("min_input: float" ) |
2984 | .Input("max_input: float" ) |
2985 | .Input("min_filter: float" ) |
2986 | .Input("max_filter: float" ) |
2987 | .Input("min_freezed_output: float" ) |
2988 | .Input("max_freezed_output: float" ) |
2989 | .Output("output: out_type" ) |
2990 | .Output("min_output: float" ) |
2991 | .Output("max_output: float" ) |
2992 | .Attr("Tinput: quantizedtype" ) |
2993 | .Attr("Tfilter: quantizedtype" ) |
2994 | .Attr("Tbias: {float, qint32}" ) |
2995 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
2996 | .Attr("strides: list(int)" ) |
2997 | .Attr(GetPaddingAttrString()) |
2998 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
2999 | .Attr("padding_list: list(int) = []" ) |
3000 | .SetShapeFn([](InferenceContext* c) { |
3001 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
3002 | ShapeHandle unused, channel; |
3003 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
3004 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3005 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
3006 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
3007 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); |
3008 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
3009 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); |
3010 | c->set_output(1, c->Scalar()); |
3011 | c->set_output(2, c->Scalar()); |
3012 | return OkStatus(); |
3013 | }); |
3014 | |
3015 | // Fusion of Quantized Conv2D, BiasAdd, Sum, and Relu. |
3016 | REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu" ) |
3017 | .Input("input: Tinput" ) |
3018 | .Input("filter: Tfilter" ) |
3019 | .Input("bias: float" ) |
3020 | .Input("min_input: float" ) |
3021 | .Input("max_input: float" ) |
3022 | .Input("min_filter: float" ) |
3023 | .Input("max_filter: float" ) |
3024 | .Input("summand: float" ) |
3025 | .Output("output: out_type" ) |
3026 | .Output("min_output: float" ) |
3027 | .Output("max_output: float" ) |
3028 | .Attr("Tinput: quantizedtype" ) |
3029 | .Attr("Tfilter: quantizedtype" ) |
3030 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
3031 | .Attr("strides: list(int)" ) |
3032 | .Attr(GetPaddingAttrString()) |
3033 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
3034 | .Attr("padding_list: list(int) = []" ) |
3035 | .SetShapeFn([](InferenceContext* c) { |
3036 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
3037 | ShapeHandle unused, channel; |
3038 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
3039 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3040 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
3041 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
3042 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); |
3043 | c->set_output(1, channel); |
3044 | c->set_output(2, channel); |
3045 | return OkStatus(); |
3046 | }); |
3047 | |
3048 | REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize" ) |
3049 | .Input("input: Tinput" ) |
3050 | .Input("filter: Tfilter" ) |
3051 | .Input("bias: Tbias" ) |
3052 | .Input("min_input: float" ) |
3053 | .Input("max_input: float" ) |
3054 | .Input("min_filter: float" ) |
3055 | .Input("max_filter: float" ) |
3056 | .Input("min_freezed_output: float" ) |
3057 | .Input("max_freezed_output: float" ) |
3058 | .Input("summand: Tsummand" ) |
3059 | .Input("min_summand: float" ) |
3060 | .Input("max_summand: float" ) |
3061 | .Output("output: out_type" ) |
3062 | .Output("min_output: float" ) |
3063 | .Output("max_output: float" ) |
3064 | .Attr("Tinput: quantizedtype" ) |
3065 | .Attr("Tfilter: quantizedtype" ) |
3066 | .Attr("Tbias: {float, qint32}" ) |
3067 | .Attr("Tsummand: quantizedtype" ) |
3068 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
3069 | .Attr("strides: list(int)" ) |
3070 | .Attr(GetPaddingAttrString()) |
3071 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
3072 | .Attr("padding_list: list(int) = []" ) |
3073 | .SetShapeFn([](InferenceContext* c) { |
3074 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
3075 | ShapeHandle unused, channel; |
3076 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
3077 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3078 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
3079 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
3080 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); |
3081 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
3082 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); |
3083 | c->set_output(1, c->Scalar()); |
3084 | c->set_output(2, c->Scalar()); |
3085 | return OkStatus(); |
3086 | }); |
3087 | |
3088 | REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize" ) |
3089 | .Input("input: Tinput" ) |
3090 | .Input("filter: Tfilter" ) |
3091 | .Input("bias: Tbias" ) |
3092 | .Input("min_input: float" ) |
3093 | .Input("max_input: float" ) |
3094 | .Input("min_filter: float" ) |
3095 | .Input("max_filter: float" ) |
3096 | .Input("min_freezed_output: float" ) |
3097 | .Input("max_freezed_output: float" ) |
3098 | .Input("summand: Tsummand" ) |
3099 | .Input("min_summand: float" ) |
3100 | .Input("max_summand: float" ) |
3101 | .Output("output: out_type" ) |
3102 | .Output("min_output: float" ) |
3103 | .Output("max_output: float" ) |
3104 | .Attr("Tinput: quantizedtype" ) |
3105 | .Attr("Tfilter: quantizedtype" ) |
3106 | .Attr("Tbias: {float, qint32}" ) |
3107 | .Attr("Tsummand: quantizedtype" ) |
3108 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
3109 | .Attr("strides: list(int)" ) |
3110 | .Attr(GetPaddingAttrString()) |
3111 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
3112 | .Attr("padding_list: list(int) = []" ) |
3113 | .SetShapeFn([](InferenceContext* c) { |
3114 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
3115 | ShapeHandle unused, channel; |
3116 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
3117 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3118 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
3119 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
3120 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(6), 1, &channel)); |
3121 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
3122 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); |
3123 | // Since activations are not requantized per channel, `min_output` |
3124 | // and `max_output` are scalars. |
3125 | c->set_output(1, c->Scalar()); |
3126 | c->set_output(2, c->Scalar()); |
3127 | return OkStatus(); |
3128 | }); |
3129 | |
3130 | // Fusion of Quantized MatMul and BiasAdd. |
3131 | REGISTER_OP("QuantizedMatMulWithBias" ) |
3132 | .Input("a: T1" ) |
3133 | .Input("b: T2" ) |
3134 | .Input("bias: Tbias" ) |
3135 | .Input("min_a: float" ) |
3136 | .Input("max_a: float" ) |
3137 | .Input("min_b: float" ) |
3138 | .Input("max_b: float" ) |
3139 | .Output("out: Toutput" ) |
3140 | .Output("min_out: float" ) |
3141 | .Output("max_out: float" ) |
3142 | .Attr("T1: quantizedtype" ) |
3143 | .Attr("T2: quantizedtype" ) |
3144 | .Attr("Tbias: {float, qint32}" ) |
3145 | .Attr("Toutput: quantizedtype = DT_QINT32" ) |
3146 | .Attr("transpose_a: bool = false" ) |
3147 | .Attr("transpose_b: bool = false" ) |
3148 | .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'" ) |
3149 | .SetShapeFn([](InferenceContext* c) { |
3150 | TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); |
3151 | ShapeHandle unused; |
3152 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
3153 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3154 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
3155 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
3156 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
3157 | c->set_output(1, c->Scalar()); |
3158 | c->set_output(2, c->Scalar()); |
3159 | return OkStatus(); |
3160 | }); |
3161 | |
3162 | REGISTER_OP("QuantizedMatMulWithBiasAndRelu" ) |
3163 | .Input("a: T1" ) |
3164 | .Input("b: T2" ) |
3165 | .Input("bias: float" ) |
3166 | .Input("min_a: float" ) |
3167 | .Input("max_a: float" ) |
3168 | .Input("min_b: float" ) |
3169 | .Input("max_b: float" ) |
3170 | .Output("out: Toutput" ) |
3171 | .Output("min_out: float" ) |
3172 | .Output("max_out: float" ) |
3173 | .Attr("T1: quantizedtype" ) |
3174 | .Attr("T2: quantizedtype" ) |
3175 | .Attr("Toutput: quantizedtype = DT_QINT32" ) |
3176 | .Attr("transpose_a: bool = false" ) |
3177 | .Attr("transpose_b: bool = false" ) |
3178 | .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'" ) |
3179 | .SetShapeFn([](InferenceContext* c) { |
3180 | TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); |
3181 | ShapeHandle unused; |
3182 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
3183 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3184 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
3185 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
3186 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
3187 | c->set_output(1, c->Scalar()); |
3188 | c->set_output(2, c->Scalar()); |
3189 | return OkStatus(); |
3190 | }); |
3191 | |
3192 | REGISTER_OP("QuantizedMatMulWithBiasAndReluAndRequantize" ) |
3193 | .Input("a: T1" ) |
3194 | .Input("b: T2" ) |
3195 | .Input("bias: Tbias" ) |
3196 | .Input("min_a: float" ) |
3197 | .Input("max_a: float" ) |
3198 | .Input("min_b: float" ) |
3199 | .Input("max_b: float" ) |
3200 | .Input("min_freezed_output: float" ) |
3201 | .Input("max_freezed_output: float" ) |
3202 | .Output("out: Toutput" ) |
3203 | .Output("min_out: float" ) |
3204 | .Output("max_out: float" ) |
3205 | .Attr("T1: quantizedtype" ) |
3206 | .Attr("T2: quantizedtype" ) |
3207 | .Attr("Tbias: {float, qint32}" ) |
3208 | .Attr("Toutput: quantizedtype = DT_QUINT8" ) |
3209 | .Attr("transpose_a: bool = false" ) |
3210 | .Attr("transpose_b: bool = false" ) |
3211 | .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'" ) |
3212 | .SetShapeFn([](InferenceContext* c) { |
3213 | TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); |
3214 | ShapeHandle unused; |
3215 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
3216 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3217 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
3218 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
3219 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
3220 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
3221 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); |
3222 | c->set_output(1, c->Scalar()); |
3223 | c->set_output(2, c->Scalar()); |
3224 | return OkStatus(); |
3225 | }); |
3226 | |
3227 | REGISTER_OP("QuantizedMatMulWithBiasAndDequantize" ) |
3228 | .Input("a: T1" ) |
3229 | .Input("b: T2" ) |
3230 | .Input("bias: Tbias" ) |
3231 | .Input("min_a: float" ) |
3232 | .Input("max_a: float" ) |
3233 | .Input("min_b: float" ) |
3234 | .Input("max_b: float" ) |
3235 | .Input("min_freezed_output: float" ) |
3236 | .Input("max_freezed_output: float" ) |
3237 | .Output("out: Toutput" ) |
3238 | .Attr("T1: quantizedtype" ) |
3239 | .Attr("T2: quantizedtype" ) |
3240 | .Attr("Tbias: {float, qint32}" ) |
3241 | .Attr("Toutput: {float}" ) |
3242 | .Attr("transpose_a: bool = false" ) |
3243 | .Attr("transpose_b: bool = false" ) |
3244 | .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'" ) |
3245 | .SetShapeFn([](InferenceContext* c) { |
3246 | TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); |
3247 | ShapeHandle unused; |
3248 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
3249 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3250 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
3251 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
3252 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
3253 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
3254 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); |
3255 | |
3256 | return OkStatus(); |
3257 | }); |
3258 | |
3259 | REGISTER_OP("QuantizedMatMulWithBiasAndRequantize" ) |
3260 | .Input("a: T1" ) |
3261 | .Input("b: T2" ) |
3262 | .Input("bias: Tbias" ) |
3263 | .Input("min_a: float" ) |
3264 | .Input("max_a: float" ) |
3265 | .Input("min_b: float" ) |
3266 | .Input("max_b: float" ) |
3267 | .Input("min_freezed_output: float" ) |
3268 | .Input("max_freezed_output: float" ) |
3269 | .Output("out: Toutput" ) |
3270 | .Output("min_out: float" ) |
3271 | .Output("max_out: float" ) |
3272 | .Attr("T1: quantizedtype" ) |
3273 | .Attr("T2: quantizedtype" ) |
3274 | .Attr("Tbias: {float, qint32}" ) |
3275 | .Attr("Toutput: quantizedtype = DT_QUINT8" ) |
3276 | .Attr("transpose_a: bool = false" ) |
3277 | .Attr("transpose_b: bool = false" ) |
3278 | .Attr("input_quant_mode: {'MIN_FIRST', 'SCALED'} = 'MIN_FIRST'" ) |
3279 | .SetShapeFn([](InferenceContext* c) { |
3280 | TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); |
3281 | ShapeHandle unused; |
3282 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
3283 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3284 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
3285 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
3286 | TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); |
3287 | TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused)); |
3288 | TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused)); |
3289 | c->set_output(1, c->Scalar()); |
3290 | c->set_output(2, c->Scalar()); |
3291 | return OkStatus(); |
3292 | }); |
3293 | |
3294 | REGISTER_OP("QuantizedConv2DPerChannel" ) |
3295 | .Input("input: Tinput" ) |
3296 | .Input("filter: Tfilter" ) |
3297 | .Input("min_input: float" ) |
3298 | .Input("max_input: float" ) |
3299 | .Input("min_filter: float" ) |
3300 | .Input("max_filter: float" ) |
3301 | .Output("output: out_type" ) |
3302 | .Output("min_output: float" ) |
3303 | .Output("max_output: float" ) |
3304 | .Attr("Tinput: quantizedtype" ) |
3305 | .Attr("Tfilter: quantizedtype" ) |
3306 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
3307 | .Attr("strides: list(int)" ) |
3308 | .Attr(GetPaddingAttrString()) |
3309 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
3310 | .SetShapeFn([](InferenceContext* c) { |
3311 | TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c)); |
3312 | ShapeHandle unused, channel; |
3313 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
3314 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
3315 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(4), 1, &channel)); |
3316 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(5), 1, &channel)); |
3317 | c->set_output(1, channel); |
3318 | c->set_output(2, channel); |
3319 | return OkStatus(); |
3320 | }); |
3321 | |
3322 | REGISTER_OP("QuantizedDepthwiseConv2D" ) |
3323 | .Input("input: Tinput" ) |
3324 | .Input("filter: Tfilter" ) |
3325 | .Input("min_input: float" ) |
3326 | .Input("max_input: float" ) |
3327 | .Input("min_filter: float" ) |
3328 | .Input("max_filter: float" ) |
3329 | .Output("output: out_type" ) |
3330 | .Output("min_output: float" ) |
3331 | .Output("max_output: float" ) |
3332 | .Attr("Tinput: quantizedtype" ) |
3333 | .Attr("Tfilter: quantizedtype" ) |
3334 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
3335 | .Attr("strides: list(int)" ) |
3336 | .Attr(GetPaddingAttrString()) |
3337 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
3338 | .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); |
3339 | |
3340 | REGISTER_OP("QuantizedDepthwiseConv2DWithBias" ) |
3341 | .Input("input: Tinput" ) |
3342 | .Input("filter: Tfilter" ) |
3343 | .Input("bias: float" ) |
3344 | .Input("min_input: float" ) |
3345 | .Input("max_input: float" ) |
3346 | .Input("min_filter: float" ) |
3347 | .Input("max_filter: float" ) |
3348 | .Output("output: out_type" ) |
3349 | .Output("min_output: float" ) |
3350 | .Output("max_output: float" ) |
3351 | .Attr("Tinput: quantizedtype" ) |
3352 | .Attr("Tfilter: quantizedtype" ) |
3353 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
3354 | .Attr("strides: list(int)" ) |
3355 | .Attr(GetPaddingAttrString()) |
3356 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
3357 | .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); |
3358 | |
3359 | REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndRelu" ) |
3360 | .Input("input: Tinput" ) |
3361 | .Input("filter: Tfilter" ) |
3362 | .Input("bias: float" ) |
3363 | .Input("min_input: float" ) |
3364 | .Input("max_input: float" ) |
3365 | .Input("min_filter: float" ) |
3366 | .Input("max_filter: float" ) |
3367 | .Output("output: out_type" ) |
3368 | .Output("min_output: float" ) |
3369 | .Output("max_output: float" ) |
3370 | .Attr("Tinput: quantizedtype" ) |
3371 | .Attr("Tfilter: quantizedtype" ) |
3372 | .Attr("out_type: quantizedtype = DT_QINT32" ) |
3373 | .Attr("strides: list(int)" ) |
3374 | .Attr(GetPaddingAttrString()) |
3375 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
3376 | .Attr("padding_list: list(int) = []" ) |
3377 | .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); |
3378 | |
3379 | REGISTER_OP("QuantizedDepthwiseConv2DWithBiasAndReluAndRequantize" ) |
3380 | .Input("input: Tinput" ) |
3381 | .Input("filter: Tfilter" ) |
3382 | .Input("bias: Tbias" ) |
3383 | .Input("min_input: float" ) |
3384 | .Input("max_input: float" ) |
3385 | .Input("min_filter: float" ) |
3386 | .Input("max_filter: float" ) |
3387 | .Input("min_freezed_output: float" ) |
3388 | .Input("max_freezed_output: float" ) |
3389 | .Output("output: out_type" ) |
3390 | .Output("min_output: float" ) |
3391 | .Output("max_output: float" ) |
3392 | .Attr("Tinput: quantizedtype" ) |
3393 | .Attr("Tfilter: quantizedtype" ) |
3394 | .Attr("Tbias: {float, qint32}" ) |
3395 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
3396 | .Attr("strides: list(int)" ) |
3397 | .Attr(GetPaddingAttrString()) |
3398 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ) |
3399 | .Attr("padding_list: list(int) = []" ) |
3400 | .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape); |
3401 | |
3402 | REGISTER_OP("IsotonicRegression" ) |
3403 | .Input("input: T" ) |
3404 | .Output("output: output_dtype" ) |
3405 | .Output("segments: int32" ) |
3406 | .Attr("T: realnumbertype" ) |
3407 | .Attr("output_dtype: {half, bfloat16, float, double} = DT_FLOAT" ) |
3408 | .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* context) { |
3409 | context->set_output(0, context->input(0)); |
3410 | context->set_output(1, context->input(0)); |
3411 | return OkStatus(); |
3412 | }); |
3413 | |
3414 | } // namespace tensorflow |
3415 | |