1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/numeric_op.h" |
18 | #include "tensorflow/core/framework/op.h" |
19 | #include "tensorflow/core/framework/shape_inference.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | using shape_inference::DimensionHandle; |
24 | using shape_inference::InferenceContext; |
25 | using shape_inference::ShapeHandle; |
26 | |
// AddN sums N >= 1 tensors element-wise.  All inputs must have shapes that
// merge to a single common shape, which is also the output shape.
REGISTER_OP("AddN")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: {numbertype, variant}")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn([](InferenceContext* c) {
      // Merge all input shapes into one; any incompatibility is an error.
      ShapeHandle cur = c->input(c->num_inputs() - 1);
      for (int i = c->num_inputs() - 2; i >= 0; --i) {
        TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
                                        "From merging shape ", i,
                                        " with other shapes.");
      }
      c->set_output(0, cur);

      DataType dtype;
      TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));

      if (dtype != DT_VARIANT) {
        // Exit early if not DT_VARIANT.
        return OkStatus();
      } else {
        // DT_VARIANT shape handle shape inference. All sizes and dtypes must
        // be the same; all shapes must be compatible via Merge.
        std::vector<shape_inference::ShapeAndType> cur_shapes_and_types;
        // Seed from the last input's handle data, mirroring the shape merge
        // above.
        auto* shapes_and_types =
            c->input_handle_shapes_and_types(c->num_inputs() - 1);
        if (shapes_and_types) {
          cur_shapes_and_types = *shapes_and_types;
        }

        for (int i = c->num_inputs() - 2; i >= 0; --i) {
          auto shapes_and_types_i = c->input_handle_shapes_and_types(i);
          if (!shapes_and_types && shapes_and_types_i) {
            // TODO(ebrevdo): Find cases where this happens and fix their shape
            // inference. If we are calling AddN on variant types, they should
            // all have consistent shape_and_type info.
            shapes_and_types = shapes_and_types_i;
          } else if (shapes_and_types && shapes_and_types_i) {
            // Handle data is present on both sides: element counts, dtypes,
            // and shapes must all agree.
            if (shapes_and_types_i->size() != shapes_and_types->size()) {
              return errors::InvalidArgument(
                  "shapes_and_types[", i,
                  "].size() == ", shapes_and_types_i->size(),
                  " != shapes_and_types[0].size() == ",
                  shapes_and_types->size());
            }
            for (int j = 0; j < shapes_and_types->size(); ++j) {
              if (shapes_and_types->at(j).dtype !=
                  shapes_and_types_i->at(j).dtype) {
                return errors::InvalidArgument(
                    "shapes_and_types[", i, "][", j, "].dtype() == ",
                    DataTypeString(shapes_and_types_i->at(j).dtype),
                    " != shapes_and_types[0][", j, "].dtype == ",
                    DataTypeString(shapes_and_types->at(j).dtype));
              }
              TF_RETURN_WITH_CONTEXT_IF_ERROR(
                  c->Merge(shapes_and_types_i->at(j).shape,
                           cur_shapes_and_types.at(j).shape,
                           &cur_shapes_and_types.at(j).shape),
                  "From merging shapes_and_types[", i, "][", j, "].shape with ",
                  "shapes_and_types[0][", j, "].shape");
            }
          }
        }
        // Only propagate handle data if at least one input carried it.
        if (shapes_and_types) {
          c->set_output_handle_shapes_and_types(0, cur_shapes_and_types);
        }
        return OkStatus();
      }
    });
98 | |
99 | // -------------------------------------------------------------------------- |
100 | |
101 | // Note that the following operator is just a placeholder and has no |
102 | // associated kernel. The code in accumulate_n_optimizer.cc replaces |
103 | // this placeholder with a graph of operators that do have kernels. |
104 | // The Python code that generates instances of this op is currently in |
105 | // contrib/framework/python/ops/accumulate_n_v2.py |
// AccumulateNV2: like AddN but with an explicit 'shape' attr; the output
// shape comes from that attr (ExplicitShape) rather than from merging the
// input shapes.
REGISTER_OP("AccumulateNV2")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .Attr("shape: shape")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn(shape_inference::ExplicitShape);
115 | |
116 | // -------------------------------------------------------------------------- |
117 | |
// BatchMatMul: batched matrix product of x and y, with optional adjoint
// (adj_x / adj_y) of either operand.
REGISTER_OP("BatchMatMul")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulShape);

// V2 adds int16 to the type list and uses BatchMatMulV2Shape instead of
// BatchMatMulShape for shape inference.
REGISTER_OP("BatchMatMulV2")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr(
        "T: {bfloat16, half, float, double, int16, int32, int64, complex64, "
        "complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);

// V3 allows the two operands (Ta, Tb) and the output (Tout) to have
// distinct dtypes; shape inference is shared with V2.
REGISTER_OP("BatchMatMulV3")
    .Input("x: Ta")
    .Input("y: Tb")
    .Output("output: Tout")
    .Attr(
        "Ta: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128}")
    .Attr(
        "Tb: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128}")
    .Attr(
        "Tout: {bfloat16, half, float, double, int16, int32, int64, complex64, "
        "complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);
156 | |
#ifdef INTEL_MKL
// oneDNN/MKL variants of the batch matmul ops, restricted to bfloat16/float.
// These are internal ops (leading underscore) substituted in by graph
// rewrites, not meant to be created directly by users.
REGISTER_OP("_MklBatchMatMul")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulShape);

REGISTER_OP("_MklBatchMatMulV2")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);
#endif  // INTEL_MKL
176 | |
177 | // -------------------------------------------------------------------------- |
178 | // Casting Ops |
179 | // |
180 | // NOTE: Only a smaller number of types are supported by |
181 | // Cast. The exact casting rule is TBD. The current |
182 | // implementation uses C++ static cast rules for numeric |
183 | // types, which may be changed in the future. |
// Cast x of type SrcT to y of DstT, preserving shape.  'Truncate' selects
// truncating conversion behavior (defaults to false).
REGISTER_OP("Cast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .Attr("Truncate: bool = false")
    .SetTypeConstructor(full_type::NoOp())
    .SetForwardTypeFn(full_type::KeepExisting())
    .SetShapeFn(shape_inference::UnchangedShape);

// Same as Cast, but pinned to host memory (see Doc below).
REGISTER_OP("_HostCast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .Attr("Truncate: bool = false")
    .SetTypeConstructor(full_type::NoOp())
    .SetForwardTypeFn(full_type::KeepExisting())
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Cast x of type SrcT to y of DstT.

_HostCast requires its input and produces its output in host memory.
)doc");
208 | |
209 | // -------------------------------------------------------------------------- |
210 | |
// Abs: element-wise absolute value (real types only; shape unchanged).
REGISTER_OP("Abs")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double, int8, int16, int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape);

// ComplexAbs: element-wise magnitude of a complex tensor; output is the
// matching real type (complex64 -> float, complex128 -> double).
REGISTER_OP("ComplexAbs")
    .Input("x: T")
    .Output("y: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
223 | |
224 | // Declares cwise unary operations signature: 't -> 't |
// Declares cwise unary operations signature: 't -> 't.
// The variants below differ only in the allowed dtype set (and, for the
// gradient variant, in the input signature):
//   UNARY():                  signed ints + floats + complex.
//   UNARY_UNSIGNED():         UNARY() plus unsigned int types.
//   UNARY_REAL():             real floating-point types only.
//   UNARY_COMPLEX():          floats + complex.
//   UNARY_GRADIENT_COMPLEX(): gradient signature (y, dy) -> z with the
//                             UNARY_COMPLEX() dtype set.
// NOTE: no comments inside the macro bodies — a '//' would comment out the
// line-continuation backslash.
#define UNARY()                                                              \
  Input("x: T")                                                              \
      .Output("y: T")                                                        \
      .Attr(                                                                 \
          "T: {bfloat16, half, float, double, int8, int16, int32, int64, "   \
          "complex64, complex128}")                                          \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_UNSIGNED()                                                     \
  Input("x: T")                                                              \
      .Output("y: T")                                                        \
      .Attr(                                                                 \
          "T: {bfloat16, half, float, double, int8, int16, int32, int64, "   \
          "uint8, uint16, uint32, uint64, complex64, complex128}")           \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_REAL()                                                         \
  Input("x: T")                                                              \
      .Output("y: T")                                                        \
      .Attr("T: {bfloat16, half, float, double}")                            \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_COMPLEX()                                                      \
  Input("x: T")                                                              \
      .Output("y: T")                                                        \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}")     \
      .SetShapeFn(shape_inference::UnchangedShape)

#define UNARY_GRADIENT_COMPLEX()                                             \
  Input("y: T")                                                              \
      .Input("dy: T")                                                        \
      .Output("z: T")                                                        \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}")     \
      .SetShapeFn(shape_inference::UnchangedShape)
259 | |
// Element-wise unary math ops, registered via the UNARY* macros above.
// Inv/Reciprocal are the same operation (Inv is the legacy name); the
// *Grad ops take (y, dy) and produce the input gradient z.
REGISTER_OP("Neg").UNARY();

REGISTER_OP("Inv").UNARY();

REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Reciprocal").UNARY();

REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Square").UNARY_UNSIGNED();

REGISTER_OP("Sqrt").UNARY_COMPLEX();

REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Rsqrt").UNARY_COMPLEX();

REGISTER_OP("Round").UNARY();

REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX();

// Exponentials and logarithms (Expm1 = exp(x)-1, Log1p = log(1+x)).
REGISTER_OP("Exp").UNARY_COMPLEX();

REGISTER_OP("Expm1").UNARY_COMPLEX();

REGISTER_OP("Log").UNARY_COMPLEX();

REGISTER_OP("Log1p").UNARY_COMPLEX();

// Hyperbolic and inverse-hyperbolic functions.
REGISTER_OP("Sinh").UNARY_COMPLEX();

REGISTER_OP("Cosh").UNARY_COMPLEX();

REGISTER_OP("Tanh").UNARY_COMPLEX();

REGISTER_OP("Asinh").UNARY_COMPLEX();

REGISTER_OP("Acosh").UNARY_COMPLEX();

REGISTER_OP("Atanh").UNARY_COMPLEX();

REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();

// Special functions (real-valued only).
REGISTER_OP("Lgamma").UNARY_REAL();

REGISTER_OP("Digamma").UNARY_REAL();

REGISTER_OP("Erf").UNARY_REAL();
REGISTER_OP("Erfinv").UNARY_REAL();
REGISTER_OP("Ndtri").UNARY_REAL();
REGISTER_OP("Erfc").UNARY_REAL();

REGISTER_OP("Sigmoid").UNARY_COMPLEX();

REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();

// Trigonometric and inverse-trigonometric functions.
REGISTER_OP("Sin").UNARY_COMPLEX();

REGISTER_OP("Cos").UNARY_COMPLEX();

REGISTER_OP("Tan").UNARY();

REGISTER_OP("Asin").UNARY();

REGISTER_OP("Acos").UNARY();

REGISTER_OP("Atan").UNARY();
328 | |
// Internal fused op: applies the chain of unary ops named in 'op_names' to
// its input.  Created only by graph rewrites (see Doc below).
REGISTER_OP("_UnaryOpsComposition")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, half, double}")
    .Attr("op_names: list(string)")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to create these operators.
)doc");
339 | |
// All unary ops are registered; clean up the helper macros so they cannot
// leak into later code.  UNARY_UNSIGNED and UNARY_GRADIENT_COMPLEX were
// previously left defined here even though their last uses (Square and
// SigmoidGrad) are above — undefine them too for consistency.
#undef UNARY
#undef UNARY_UNSIGNED
#undef UNARY_REAL
#undef UNARY_COMPLEX
#undef UNARY_GRADIENT_COMPLEX
343 | |
// Element-wise floating-point classification predicates: real float input,
// bool output, shape unchanged.
REGISTER_OP("IsNan")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsInf")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsFinite")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Sign: element-wise sign indicator over signed numeric and complex types.
REGISTER_OP("Sign")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Element-wise rounding ops on real floating-point types.
REGISTER_OP("Floor")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Ceil")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Rint")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
387 | |
388 | // Declares cwise binary operations signature: 't, 't -> 't. |
389 | |
// Declares cwise binary operations signature: 't, 't -> 't.
// BINARY_MORE() allows the full signed/unsigned/float/complex set;
// BINARY_FEWER() omits the 8/16-bit and unsigned integer types.
// Neither macro sets a shape fn — each registration adds its own.
#define BINARY_MORE()                                                          \
  Input("x: T").Input("y: T").Output("z: T").Attr(                             \
      "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, " \
      "uint32, uint64, int64, complex64, complex128}")

#define BINARY_FEWER()                                                         \
  Input("x: T").Input("y: T").Output("z: T").Attr(                             \
      "T: {bfloat16, half, float, double, int32, int64, complex64, "           \
      "complex128}")
399 | |
// Add: broadcasting element-wise addition.  Note the legacy type list
// includes string (concatenation); AddV2 drops string, adds the unsigned
// integer types, and is marked aggregate/commutative so optimizers may
// reorder and tree-reduce it.
REGISTER_OP("Add")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128, string}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("AddV2")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, uint16, uint32, uint64, "
        "int8, int16, int32, int64, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative();
419 | |
#ifdef INTEL_MKL
// MKL-layout variants of Add/AddV2.  The extra uint8 inputs/outputs carry
// MKL tensor-layout metadata alongside each data tensor.
REGISTER_OP("_MklAdd")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
        "complex128, string, bfloat16}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns `x` + `y` element-wise.

*NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");

REGISTER_OP("_MklAddV2")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative()
    .Doc(R"doc(
Returns `x` + `y` element-wise.
*NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");
#endif  // INTEL_MKL
458 | |
// Sub: broadcasting element-wise subtraction.
REGISTER_OP("Sub")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, "
        "int64, complex64, complex128, uint32, uint64}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// MKL-layout variant of Sub; the uint8 inputs/output carry layout metadata.
REGISTER_OP("_MklSub")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x - y element-wise.

*NOTE*: `Sub` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

// Mul: broadcasting element-wise multiplication (commutative).
REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

// MulNoNan: like Mul, but (per its kernel contract) restricted to floating
// and complex types.
REGISTER_OP("MulNoNan")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note: This op is not commutative w.r.t. to all its inputs.
REGISTER_OP("_MklMul")
    .BINARY_MORE()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x * y element-wise.

*NOTE*: `Mul` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
504 | |
// Division family, all broadcasting element-wise ops.
REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

// DivNoNan: float/complex only (per its type list).
REGISTER_OP("DivNoNan")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, bfloat16, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

// Note SquaredDifference implements conj(x - y)*(x - y).
REGISTER_OP("SquaredDifference")
    .BINARY_FEWER()
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note: This op is not commutative w.r.t. to all its inputs.
REGISTER_OP("_MklSquaredDifference")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns (x - y)(x - y) element-wise.

*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
545 | |
// x * log(y), x * log1p(y), and x / y variants; all are broadcasting
// binary ops over float/complex types.
REGISTER_OP("Xlogy")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Xlog1py")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Xdivy")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Binary helper macros are no longer needed past this point.
#undef BINARY_FEWER
#undef BINARY_MORE
569 | |
// Maximum/Minimum: broadcasting element-wise max/min over real types.
REGISTER_OP("Maximum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
        "int32, uint32, int64, uint64}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note: This op is not commutative w.r.t. to all its inputs.
REGISTER_OP("_MklMaximum")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr("T: {half, float, double, int32, int64, bfloat16}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns the max of x and y (i.e. x > y ? x : y) element-wise.

*NOTE*: `Maximum` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Minimum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
        "int32, uint32, int64, uint64}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Modulo family.
// NOTE(review): Mod's type list names both 'float16' and 'half', which
// presumably alias the same dtype — confirm; the string is kept as-is since
// the attr text is part of the serialized OpDef.
REGISTER_OP("Mod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64, "
        "bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Pow: broadcasting element-wise exponentiation.
REGISTER_OP("Pow")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, float, half, double, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
636 | |
// Special two-argument functions, all broadcasting binary ops over real
// floating-point types (some restricted to float/double).
REGISTER_OP("Igammac")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Igamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Gradient of Igamma w.r.t. its first argument 'a'.
REGISTER_OP("IgammaGradA")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Zeta")
    .Input("x: T")
    .Input("q: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Polygamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Atan2: note the argument order is (y, x), matching atan2(y, x).
REGISTER_OP("Atan2")
    .Input("y: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
678 | |
// Betainc(a, b, x): the shape fn implements scalar broadcasting only —
// every non-scalar input must Merge to the same shape, and scalar inputs
// are broadcast to that shape.
REGISTER_OP("Betainc")
    .Input("a: T")
    .Input("b: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn([](InferenceContext* c) {
      const int num_inputs = 3;
      ShapeHandle output = c->UnknownShape();
      int num_scalars = 0;  // count of inputs known to be rank-0
      ShapeHandle some_non_scalar;
      for (int i = 0; i < num_inputs; ++i) {
        ShapeHandle in = c->input(i);
        if (!c->RankKnown(in)) {
          some_non_scalar = in;
          // An input with unknown rank could be either a scalar (to be
          // broadcast) or some other shape.
        } else if (c->Rank(in) == 0) {
          // Input is a scalar, it will be broadcast to the output shape.
          ++num_scalars;
        } else {
          // All ranked non-scalar inputs must agree on one shape.
          TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
          some_non_scalar = output;
        }
      }

      if (num_scalars == num_inputs - 1) {
        // If all but one input is known to be a scalar, then output is the
        // remaining input.
        output = some_non_scalar;
      } else if (num_scalars == num_inputs) {
        // If all are scalars, output is scalar; pick the first one arbitrarily.
        output = c->input(0);
      }

      c->set_output(0, output);
      return OkStatus();
    });
717 | |
718 | // -------------------------------------------------------------------------- |
719 | |
720 | // Declares cwise binary comparison operations signature: 't, 't -> bool, |
721 | // where 't has a natural total order. |
// Declares cwise binary comparison operations signature: 't, 't -> bool,
// where 't has a natural total order.  Output shape is the broadcast of
// the two input shapes.
#define COMPARISON()                                                \
  Input("x: T")                                                     \
      .Input("y: T")                                                \
      .Output("z: bool")                                            \
      .Attr("T: realnumbertype")                                    \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("Less").COMPARISON();

REGISTER_OP("LessEqual").COMPARISON();

REGISTER_OP("Greater").COMPARISON();

REGISTER_OP("GreaterEqual").COMPARISON();

#undef COMPARISON
738 | |
739 | // -------------------------------------------------------------------------- |
740 | |
// Declares equality comparison ops: 't, 't -> bool for any type T.  The
// shape fn honors the 'incompatible_shape_error' attr: when it is false,
// BroadcastBinaryOpOutputShapeFnHelper tolerates non-broadcastable inputs
// instead of returning an error.  (No comments inside the macro body — '//'
// would comment out the continuation backslashes.)
#define EQUALITY_COMPARISON()                                              \
  Input("x: T")                                                            \
      .Input("y: T")                                                       \
      .Output("z: bool")                                                   \
      .SetIsCommutative()                                                  \
      .Attr("T: type")                                                     \
      .Attr("incompatible_shape_error: bool = true")                       \
      .SetShapeFn([](InferenceContext* c) {                                \
        ShapeHandle x = c->input(0);                                       \
        ShapeHandle y = c->input(1);                                       \
        ShapeHandle output;                                                \
        bool incompatible_shape_error;                                     \
        TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error",          \
                                      &incompatible_shape_error));         \
        TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(           \
            c, x, y, incompatible_shape_error, &output));                  \
        c->set_output(0, output);                                          \
        return OkStatus();                                                 \
      })

REGISTER_OP("Equal").EQUALITY_COMPARISON();

REGISTER_OP("NotEqual").EQUALITY_COMPARISON();

#undef EQUALITY_COMPARISON
766 | |
// ApproximateEqual: element-wise |x - y| comparison against 'tolerance'.
// Unlike Equal, the inputs must have identical (mergeable) shapes — the
// shape fn merges them rather than broadcasting.
REGISTER_OP("ApproximateEqual")
    .Input("x: T")
    .Input("y: T")
    .Output("z: bool")
    .SetIsCommutative()
    .Attr("T: numbertype")
    .Attr("tolerance: float = 0.00001")
    .SetShapeFn([](InferenceContext* c) {
      // The inputs 'x' and 'y' must have the same shape.
      ShapeHandle data_x = c->input(0);
      ShapeHandle data_y = c->input(1);
      TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
      // Output shape equals the (merged) input shape.
      return shape_inference::UnchangedShape(c);
    });
781 | |
782 | // -------------------------------------------------------------------------- |
783 | |
// Boolean logic ops.  LogicalNot is unary (shape unchanged); the binary
// ops broadcast and are commutative.
REGISTER_OP("LogicalNot")
    .Input("x: bool")
    .Output("y: bool")
    .SetShapeFn(shape_inference::UnchangedShape);

#define BINARY_LOGICAL()                                            \
  Input("x: bool")                                                  \
      .Input("y: bool")                                             \
      .Output("z: bool")                                            \
      .SetIsCommutative()                                           \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("LogicalAnd").BINARY_LOGICAL();

REGISTER_OP("LogicalOr").BINARY_LOGICAL();

#undef BINARY_LOGICAL
801 | |
802 | // -------------------------------------------------------------------------- |
803 | |
804 | REGISTER_OP("Select" ) |
805 | .Input("condition: bool" ) |
806 | .Input("t: T" ) |
807 | .Input("e: T" ) |
808 | .Output("output: T" ) |
809 | .Attr("T: type" ) |
810 | .SetShapeFn([](InferenceContext* c) { |
811 | auto* handle_data_1 = c->input_handle_shapes_and_types(1); |
812 | auto* handle_data_2 = c->input_handle_shapes_and_types(2); |
813 | // Merge handle shape and dtype if applicable. |
814 | if (handle_data_1 != nullptr && handle_data_2 != nullptr) { |
815 | const auto size = handle_data_1->size(); |
816 | std::vector<shape_inference::ShapeAndType> merged_handle_data(size); |
817 | if (size != handle_data_2->size()) { |
818 | return errors::InvalidArgument( |
819 | "Trying to merge handles pointing to different numbers of " |
820 | "tensors." ); |
821 | } |
822 | |
823 | for (int i = 0; i < size; ++i) { |
824 | const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i]; |
825 | const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i]; |
826 | if (s1.dtype != s2.dtype) { |
827 | // TODO(apassos) resolve this in the manner of b/32476923 |
828 | return errors::InvalidArgument( |
829 | "Trying to merge handles pointing to different dtypes." ); |
830 | } |
831 | merged_handle_data[i].dtype = s1.dtype; |
832 | TF_RETURN_IF_ERROR( |
833 | c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape)); |
834 | } |
835 | |
836 | c->set_output_handle_shapes_and_types(0, merged_handle_data); |
837 | } |
838 | |
839 | // The inputs 'then' and 'else' must have the same shape. |
840 | ShapeHandle data = c->input(1); |
841 | ShapeHandle other = c->input(2); |
842 | TF_RETURN_IF_ERROR(c->Merge(data, other, &data)); |
843 | |
844 | // The input 'cond' must either have the same shape as 'then' and |
845 | // 'else', or be a vector if 'then' and 'else' are at least vectors. |
846 | ShapeHandle cond = c->input(0); |
847 | |
848 | if (!c->RankKnown(cond) || !c->RankKnown(data)) { |
849 | c->set_output(0, data); |
850 | return OkStatus(); |
851 | } |
852 | |
853 | // rank of shape and data is known. |
854 | |
855 | const int32_t cond_rank = c->Rank(cond); |
856 | const int32_t data_rank = c->Rank(data); |
857 | |
858 | if (cond_rank == 0) { |
859 | // The rank of 'cond' is a scalar. |
860 | // t and e can have any shape. |
861 | c->set_output(0, data); |
862 | return OkStatus(); |
863 | } |
864 | |
865 | if (cond_rank != 1) { |
866 | // If 'cond' is not a vector, and not a scalar, |
867 | // then shape must match 'then' and 'else' |
868 | TF_RETURN_IF_ERROR(c->Merge(data, cond, &data)); |
869 | c->set_output(0, data); |
870 | return OkStatus(); |
871 | } |
872 | |
873 | if (data_rank == 0) { |
874 | // if 'then' and 'else' are scalar also the cond must be |
875 | TF_RETURN_IF_ERROR(c->Merge(data, cond, &data)); |
876 | c->set_output(0, data); |
877 | return OkStatus(); |
878 | } |
879 | |
880 | if (cond_rank == 1) { |
881 | // if the cond is a vector and the 'then' is not a scalar, |
882 | // the first dimension of 'then' and 'else' |
883 | TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond)); |
884 | c->set_output(0, data); |
885 | return OkStatus(); |
886 | } |
887 | |
888 | c->set_output(0, data); |
889 | |
890 | return OkStatus(); |
891 | }); |
892 | |
// SelectV2: like Select, but 'condition', 't', and 'e' are fully
// broadcastable against each other (shape computed by two successive
// binary-broadcast merges below).
REGISTER_OP("SelectV2")
    .Input("condition: bool")
    .Input("t: T")
    .Input("e: T")
    .Output("output: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      auto* handle_data_1 = c->input_handle_shapes_and_types(1);
      auto* handle_data_2 = c->input_handle_shapes_and_types(2);
      // Merge handle shape and dtype if applicable.
      if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
        const auto size = handle_data_1->size();
        std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
        if (size != handle_data_2->size()) {
          return errors::InvalidArgument(
              "Trying to merge handles pointing to different numbers of "
              "tensors.");
        }

        for (int i = 0; i < size; ++i) {
          const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
          const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
          if (s1.dtype != s2.dtype) {
            // TODO(apassos) resolve this in the manner of b/32476923
            return errors::InvalidArgument(
                "Trying to merge handles pointing to different dtypes.");
          }
          merged_handle_data[i].dtype = s1.dtype;
          TF_RETURN_IF_ERROR(
              c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
        }

        c->set_output_handle_shapes_and_types(0, merged_handle_data);
      }

      // The inputs 'cond', 'then', and 'else' must be broadcastable.
      // TODO (yongtang): Consolidate 3-ary broadcast instead of
      // multiple 2-ary broadcast.
      ShapeHandle cond = c->input(0);
      ShapeHandle then = c->input(1);
      ShapeHandle else_ = c->input(2);
      ShapeHandle other;
      // First broadcast 't' against 'e', then the result against 'cond'.
      TF_RETURN_IF_ERROR(
          BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, true, &other));
      ShapeHandle output;
      TF_RETURN_IF_ERROR(
          BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, true, &output));
      c->set_output(0, output);
      return OkStatus();
    });
943 | |
944 | // -------------------------------------------------------------------------- |
945 | |
946 | REGISTER_OP("MatMul" ) |
947 | .Input("a: T" ) |
948 | .Input("b: T" ) |
949 | .Output("product: T" ) |
950 | .Attr("transpose_a: bool = false" ) |
951 | .Attr("transpose_b: bool = false" ) |
952 | .Attr( |
953 | "T: {bfloat16, half, float, double, int32, int64, complex64, " |
954 | "complex128}" ) |
955 | .SetShapeFn(shape_inference::MatMulShape); |
956 | |
957 | #ifdef INTEL_MKL |
958 | REGISTER_OP("_MklMatMul" ) |
959 | .Input("a: T" ) |
960 | .Input("b: T" ) |
961 | .Output("product: T" ) |
962 | .Attr("transpose_a: bool = false" ) |
963 | .Attr("transpose_b: bool = false" ) |
964 | .Attr("T: {bfloat16, float}" ) |
965 | .SetShapeFn(shape_inference::MatMulShape); |
966 | #endif // INTEL_MKL |
967 | |
968 | REGISTER_OP("SparseMatMul" ) |
969 | .Input("a: Ta" ) |
970 | .Input("b: Tb" ) |
971 | .Output("product: float" ) |
972 | .Attr("transpose_a: bool = false" ) |
973 | .Attr("transpose_b: bool = false" ) |
974 | .Attr("a_is_sparse: bool = false" ) |
975 | .Attr("b_is_sparse: bool = false" ) |
976 | .Attr("Ta: {float, bfloat16} = DT_FLOAT" ) |
977 | .Attr("Tb: {float, bfloat16} = DT_FLOAT" ) |
978 | .SetShapeFn(shape_inference::MatMulShape); |
979 | |
980 | REGISTER_OP("_FusedMatMul" ) |
981 | .Input("a: T" ) |
982 | .Input("b: T" ) |
983 | .Input("args: num_args * T" ) |
984 | .Output("product: T" ) |
985 | .Attr("transpose_a: bool = false" ) |
986 | .Attr("transpose_b: bool = false" ) |
987 | .Attr("T: {bfloat16, half, float}" ) |
988 | .Attr("num_args: int >= 0" ) |
989 | .Attr("fused_ops: list(string) = []" ) |
990 | // Attributes for the FusedBatchNorm ----------- // |
991 | .Attr("epsilon: float = 0.0001" ) |
992 | // Attributes for the LeakyRelu ---------------- // |
993 | .Attr("leakyrelu_alpha: float = 0.2" ) |
994 | // --------------------------------------------- // |
995 | .SetShapeFn(shape_inference::MatMulShape) |
996 | .Doc(R"doc( |
997 | Performs a MatMul followed by a specified series of operations. |
998 | |
999 | The inputs to the MatMul are specified by `a` and `b`. The series of operations |
1000 | that follows is specified by the `fused_ops` attribute, which is a list of TF op |
1001 | names specified as strings (e.g. "Relu"). They are performed in order, where the |
1002 | (first) input to each op is the output of the preceding op. The first input and |
1003 | the output of each fused_op must be of type T. |
1004 | |
1005 | Currently supported fused_op combinations are: ["BiasAdd"] and ["BiasAdd",A], |
1006 | where A is one of {"Elu","Relu","Relu6"}. |
1007 | |
1008 | * The first input to BiasAdd is the MatMul result, and the additional BiasAdd |
1009 | input is specified by `args`. |
1010 | * If there is an op A specified, the output of the BiasAdd is the input to op A, |
1011 | and op A produces the _FusedConv2D output. Otherwise, the BiasAdd produces the |
1012 | _FusedConv2D output. |
1013 | |
1014 | *NOTE*: Do not invoke this operator directly in Python. Grappler is |
1015 | expected to create these operators. |
1016 | )doc" ); |
1017 | |
1018 | // -------------------------------------------------------------------------- |
1019 | |
1020 | // For operations where the output is a reduction function along some |
1021 | // dimensions of the input. |
1022 | REGISTER_OP("Sum" ) |
1023 | .Input("input: T" ) |
1024 | .Input("reduction_indices: Tidx" ) |
1025 | .Output("output: T" ) |
1026 | .Attr("keep_dims: bool = false" ) |
1027 | .Attr("T: numbertype" ) |
1028 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1029 | .SetShapeFn(shape_inference::ReductionShape); |
1030 | |
1031 | REGISTER_OP("EuclideanNorm" ) |
1032 | .Input("input: T" ) |
1033 | .Input("reduction_indices: Tidx" ) |
1034 | .Output("output: T" ) |
1035 | .Attr("keep_dims: bool = false" ) |
1036 | .Attr("T: numbertype" ) |
1037 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1038 | .SetShapeFn(shape_inference::ReductionShape); |
1039 | |
1040 | REGISTER_OP("Mean" ) |
1041 | .Input("input: T" ) |
1042 | .Input("reduction_indices: Tidx" ) |
1043 | .Output("output: T" ) |
1044 | .Attr("keep_dims: bool = false" ) |
1045 | .Attr("T: numbertype" ) |
1046 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1047 | .SetShapeFn(shape_inference::ReductionShape); |
1048 | |
1049 | REGISTER_OP("Prod" ) |
1050 | .Input("input: T" ) |
1051 | .Input("reduction_indices: Tidx" ) |
1052 | .Output("output: T" ) |
1053 | .Attr("keep_dims: bool = false" ) |
1054 | .Attr("T: numbertype" ) |
1055 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1056 | .SetShapeFn(shape_inference::ReductionShape); |
1057 | |
1058 | REGISTER_OP("Min" ) |
1059 | .Input("input: T" ) |
1060 | .Input("reduction_indices: Tidx" ) |
1061 | .Output("output: T" ) |
1062 | .Attr("keep_dims: bool = false" ) |
1063 | .Attr("T: {realnumbertype, quantizedtype}" ) |
1064 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1065 | .SetShapeFn(shape_inference::ReductionShape); |
1066 | |
1067 | REGISTER_OP("Max" ) |
1068 | .Input("input: T" ) |
1069 | .Input("reduction_indices: Tidx" ) |
1070 | .Output("output: T" ) |
1071 | .Attr("keep_dims: bool = false" ) |
1072 | .Attr("T: {realnumbertype, quantizedtype}" ) |
1073 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1074 | .SetShapeFn(shape_inference::ReductionShape); |
1075 | |
1076 | namespace { |
1077 | |
1078 | Status ArgOpShape(shape_inference::InferenceContext* c) { |
1079 | ShapeHandle dimension_shape; |
1080 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape)); |
1081 | |
1082 | ShapeHandle input_shape = c->input(0); |
1083 | if (!c->RankKnown(input_shape)) { |
1084 | return shape_inference::UnknownShape(c); |
1085 | } |
1086 | |
1087 | const int32_t input_rank = c->Rank(input_shape); |
1088 | if (input_rank <= 1) { |
1089 | // Reducing a scalar/vector must return a scalar. |
1090 | return shape_inference::ScalarShape(c); |
1091 | } |
1092 | |
1093 | const Tensor* dim_t = c->input_tensor(1); |
1094 | if (dim_t == nullptr) { |
1095 | // We don't know the value of the dimension, but we |
1096 | // know the rank of the input, so return the correct |
1097 | // rank with unknown dimensions. |
1098 | std::vector<DimensionHandle> dims(input_rank - 1); |
1099 | for (int i = 0; i < dims.size(); ++i) { |
1100 | dims[i] = c->UnknownDim(); |
1101 | } |
1102 | |
1103 | c->set_output(0, c->MakeShape(dims)); |
1104 | return OkStatus(); |
1105 | } |
1106 | |
1107 | int64_t dimension_val; |
1108 | if (dim_t->dtype() == DT_INT32) { |
1109 | dimension_val = dim_t->scalar<int32>()(); |
1110 | } else { |
1111 | dimension_val = dim_t->scalar<int64_t>()(); |
1112 | } |
1113 | |
1114 | int64_t axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val; |
1115 | if (axis < 0 || axis >= input_rank) { |
1116 | return errors::InvalidArgument( |
1117 | "Dimension (" , dimension_val, ") must be in the range [" , -input_rank, |
1118 | ", " , input_rank, "), where " , input_rank, |
1119 | " is the number of dimensions in the input." ); |
1120 | } |
1121 | |
1122 | // Return the input shape without the dimension being reduced. |
1123 | std::vector<DimensionHandle> dims; |
1124 | for (int i = 0; i < input_rank; ++i) { |
1125 | if (axis != i) { |
1126 | dims.emplace_back(c->Dim(input_shape, i)); |
1127 | } |
1128 | } |
1129 | c->set_output(0, c->MakeShape(dims)); |
1130 | return OkStatus(); |
1131 | } |
1132 | |
1133 | } // namespace |
1134 | |
1135 | REGISTER_OP("ArgMax" ) |
1136 | .Input("input: T" ) |
1137 | .Input("dimension: Tidx" ) |
1138 | .Output("output: output_type" ) |
1139 | .Attr("T: {numbertype, bool}" ) |
1140 | .Attr("Tidx: {int16, int32, int64} = DT_INT32" ) |
1141 | .Attr("output_type: {int16, uint16, int32, int64} = DT_INT64" ) |
1142 | .SetShapeFn(ArgOpShape); |
1143 | |
1144 | REGISTER_OP("ArgMin" ) |
1145 | .Input("input: T" ) |
1146 | .Input("dimension: Tidx" ) |
1147 | .Output("output: output_type" ) |
1148 | .Attr("T: {numbertype, bool}" ) |
1149 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1150 | .Attr("output_type: {int32, int64} = DT_INT64" ) |
1151 | .SetShapeFn(ArgOpShape); |
1152 | |
1153 | namespace { |
1154 | |
1155 | Status SegmentReductionShapeFn(InferenceContext* c) { |
1156 | ShapeHandle data_shape; |
1157 | ShapeHandle segment_ids_shape; |
1158 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); |
1159 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape)); |
1160 | |
1161 | ShapeHandle subshape; |
1162 | TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); |
1163 | |
1164 | ShapeHandle out; |
1165 | TF_RETURN_IF_ERROR( |
1166 | c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out)); |
1167 | c->set_output(0, out); |
1168 | return OkStatus(); |
1169 | } |
1170 | |
1171 | Status SparseSegmentReductionShapeFn(InferenceContext* c) { |
1172 | ShapeHandle data_shape; |
1173 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); |
1174 | |
1175 | ShapeHandle indices_shape; |
1176 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); |
1177 | |
1178 | ShapeHandle segment_ids_shape; |
1179 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape)); |
1180 | |
1181 | // indices and segment_ids should merge cleanly. |
1182 | ShapeHandle unused; |
1183 | TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused)); |
1184 | |
1185 | ShapeHandle subshape; |
1186 | TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); |
1187 | |
1188 | ShapeHandle out; |
1189 | TF_RETURN_IF_ERROR( |
1190 | c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out)); |
1191 | c->set_output(0, out); |
1192 | return OkStatus(); |
1193 | } |
1194 | |
// Shape function for SparseSegment{Sum,Mean,SqrtN}Grad.
// Inputs: grad (rank >= 1), indices (vector, merged against segment_ids),
// output_dim0 (scalar holding the original leading dimension). The output
// is [output_dim0] + grad.shape[1:], with the leading dimension unknown
// when output_dim0 is not a compile-time constant.
Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
  ShapeHandle data_shape;
  TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));

  ShapeHandle indices_shape;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));

  // indices and segment_ids should merge cleanly.
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));

  // output_dim0 should be a scalar
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));

  ShapeHandle subshape;
  TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));

  const Tensor* dim0 = c->input_tensor(3);
  ShapeHandle dim0_shape;
  if (dim0 == nullptr) {
    // We don't have the value at inference time, so the output
    // shape is unknown.
    dim0_shape = c->Vector(InferenceContext::kUnknownDim);
  } else {
    // output_dim0 is declared as int32 in the *Grad op registrations, so a
    // plain int32 scalar read is safe here.
    auto dim0_value = dim0->scalar<int32>()();
    if (dim0_value < 0) {
      return errors::InvalidArgument(
          "Cannot specify a negative value for output_dim0" );
    }
    dim0_shape = c->Vector(dim0_value);
  }

  ShapeHandle out;
  TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
  c->set_output(0, out);
  return OkStatus();
}
1232 | |
1233 | Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) { |
1234 | ShapeHandle data_shape; |
1235 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape)); |
1236 | |
1237 | ShapeHandle indices_shape; |
1238 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape)); |
1239 | |
1240 | ShapeHandle segment_ids_shape; |
1241 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape)); |
1242 | |
1243 | ShapeHandle num_segments_shape; |
1244 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape)); |
1245 | |
1246 | // indices and segment_ids should merge cleanly. |
1247 | ShapeHandle unused; |
1248 | TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused)); |
1249 | |
1250 | ShapeHandle subshape; |
1251 | TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape)); |
1252 | |
1253 | ShapeHandle out; |
1254 | const Tensor* dim0 = c->input_tensor(3); |
1255 | if (dim0 == nullptr) { |
1256 | // We don't have the value at inference time, so the output |
1257 | // shape is unknown. |
1258 | TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim), |
1259 | subshape, &out)); |
1260 | } else { |
1261 | auto dim0_value = dim0->scalar<int32>()(); |
1262 | if (dim0_value < 0) { |
1263 | return errors::InvalidArgument( |
1264 | "Cannot specify a negative value for num_segments" ); |
1265 | } |
1266 | TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out)); |
1267 | } |
1268 | c->set_output(0, out); |
1269 | return OkStatus(); |
1270 | } |
1271 | } // namespace |
1272 | |
1273 | REGISTER_OP("SegmentSum" ) |
1274 | .Input("data: T" ) |
1275 | .Input("segment_ids: Tindices" ) |
1276 | .Output("output: T" ) |
1277 | .Attr("T: numbertype" ) |
1278 | .Attr("Tindices: {int32,int64}" ) |
1279 | .SetShapeFn(SegmentReductionShapeFn); |
1280 | |
1281 | REGISTER_OP("SegmentMean" ) |
1282 | .Input("data: T" ) |
1283 | .Input("segment_ids: Tindices" ) |
1284 | .Output("output: T" ) |
1285 | .Attr("T: numbertype" ) |
1286 | .Attr("Tindices: {int32,int64}" ) |
1287 | .SetShapeFn(SegmentReductionShapeFn); |
1288 | |
1289 | REGISTER_OP("SegmentProd" ) |
1290 | .Input("data: T" ) |
1291 | .Input("segment_ids: Tindices" ) |
1292 | .Output("output: T" ) |
1293 | .Attr("T: numbertype" ) |
1294 | .Attr("Tindices: {int32,int64}" ) |
1295 | .SetShapeFn(SegmentReductionShapeFn); |
1296 | |
1297 | REGISTER_OP("SegmentMin" ) |
1298 | .Input("data: T" ) |
1299 | .Input("segment_ids: Tindices" ) |
1300 | .Output("output: T" ) |
1301 | .Attr("T: realnumbertype" ) |
1302 | .Attr("Tindices: {int32,int64}" ) |
1303 | .SetShapeFn(SegmentReductionShapeFn); |
1304 | |
1305 | REGISTER_OP("SegmentMax" ) |
1306 | .Input("data: T" ) |
1307 | .Input("segment_ids: Tindices" ) |
1308 | .Output("output: T" ) |
1309 | .Attr("T: realnumbertype" ) |
1310 | .Attr("Tindices: {int32,int64}" ) |
1311 | .SetShapeFn(SegmentReductionShapeFn); |
1312 | |
1313 | REGISTER_OP("UnsortedSegmentSum" ) |
1314 | .Input("data: T" ) |
1315 | .Input("segment_ids: Tindices" ) |
1316 | .Input("num_segments: Tnumsegments" ) |
1317 | .Output("output: T" ) |
1318 | .Attr("T: numbertype" ) |
1319 | .Attr("Tindices: {int32,int64}" ) |
1320 | .Attr("Tnumsegments: {int32,int64} = DT_INT32" ) |
1321 | .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn); |
1322 | |
1323 | REGISTER_OP("UnsortedSegmentMax" ) |
1324 | .Input("data: T" ) |
1325 | .Input("segment_ids: Tindices" ) |
1326 | .Input("num_segments: Tnumsegments" ) |
1327 | .Output("output: T" ) |
1328 | .Attr("T: realnumbertype" ) |
1329 | .Attr("Tindices: {int32,int64}" ) |
1330 | .Attr("Tnumsegments: {int32,int64} = DT_INT32" ) |
1331 | .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn); |
1332 | |
1333 | REGISTER_OP("UnsortedSegmentMin" ) |
1334 | .Input("data: T" ) |
1335 | .Input("segment_ids: Tindices" ) |
1336 | .Input("num_segments: Tnumsegments" ) |
1337 | .Output("output: T" ) |
1338 | .Attr("T: realnumbertype" ) |
1339 | .Attr("Tindices: {int32,int64}" ) |
1340 | .Attr("Tnumsegments: {int32,int64} = DT_INT32" ) |
1341 | .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn); |
1342 | |
1343 | REGISTER_OP("UnsortedSegmentProd" ) |
1344 | .Input("data: T" ) |
1345 | .Input("segment_ids: Tindices" ) |
1346 | .Input("num_segments: Tnumsegments" ) |
1347 | .Output("output: T" ) |
1348 | .Attr("T: numbertype" ) |
1349 | .Attr("Tindices: {int32,int64}" ) |
1350 | .Attr("Tnumsegments: {int32,int64} = DT_INT32" ) |
1351 | .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn); |
1352 | |
1353 | REGISTER_OP("SparseSegmentSum" ) |
1354 | .Input("data: T" ) |
1355 | .Input("indices: Tidx" ) |
1356 | .Input("segment_ids: Tsegmentids" ) |
1357 | .Output("output: T" ) |
1358 | .Attr("T: realnumbertype" ) |
1359 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1360 | .Attr("Tsegmentids: {int32, int64} = DT_INT32" ) |
1361 | .SetShapeFn(SparseSegmentReductionShapeFn); |
1362 | |
1363 | REGISTER_OP("SparseSegmentSumWithNumSegments" ) |
1364 | .Input("data: T" ) |
1365 | .Input("indices: Tidx" ) |
1366 | .Input("segment_ids: Tsegmentids" ) |
1367 | .Input("num_segments: Tnumsegments" ) |
1368 | .Output("output: T" ) |
1369 | .Attr("T: realnumbertype" ) |
1370 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1371 | .Attr("Tnumsegments: {int32,int64} = DT_INT32" ) |
1372 | .Attr("Tsegmentids: {int32, int64} = DT_INT32" ) |
1373 | .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); |
1374 | |
1375 | REGISTER_OP("SparseSegmentSumGrad" ) |
1376 | .Input("grad: T" ) |
1377 | .Input("indices: Tidx" ) |
1378 | .Input("segment_ids: Tsegmentids" ) |
1379 | .Input("output_dim0: int32" ) |
1380 | .Output("output: T" ) |
1381 | .Attr("T: {bfloat16, half, float, double}" ) |
1382 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1383 | .Attr("Tsegmentids: {int32, int64} = DT_INT32" ) |
1384 | .SetShapeFn(SparseSegmentReductionGradShapeFn); |
1385 | |
1386 | REGISTER_OP("SparseSegmentMean" ) |
1387 | .Input("data: T" ) |
1388 | .Input("indices: Tidx" ) |
1389 | .Input("segment_ids: Tsegmentids" ) |
1390 | .Output("output: T" ) |
1391 | .Attr("T: {bfloat16, half, float, double}" ) |
1392 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1393 | .Attr("Tsegmentids: {int32, int64} = DT_INT32" ) |
1394 | .SetShapeFn(SparseSegmentReductionShapeFn); |
1395 | |
1396 | REGISTER_OP("SparseSegmentMeanWithNumSegments" ) |
1397 | .Input("data: T" ) |
1398 | .Input("indices: Tidx" ) |
1399 | .Input("segment_ids: Tsegmentids" ) |
1400 | .Input("num_segments: Tnumsegments" ) |
1401 | .Output("output: T" ) |
1402 | .Attr("T: {bfloat16, half, float, double}" ) |
1403 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1404 | .Attr("Tnumsegments: {int32,int64} = DT_INT32" ) |
1405 | .Attr("Tsegmentids: {int32, int64} = DT_INT32" ) |
1406 | .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); |
1407 | |
1408 | REGISTER_OP("SparseSegmentMeanGrad" ) |
1409 | .Input("grad: T" ) |
1410 | .Input("indices: Tidx" ) |
1411 | .Input("segment_ids: Tsegmentids" ) |
1412 | .Input("output_dim0: int32" ) |
1413 | .Output("output: T" ) |
1414 | .Attr("T: {bfloat16, half, float, double}" ) |
1415 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1416 | .Attr("Tsegmentids: {int32, int64} = DT_INT32" ) |
1417 | .SetShapeFn(SparseSegmentReductionGradShapeFn); |
1418 | |
1419 | REGISTER_OP("SparseSegmentSqrtN" ) |
1420 | .Input("data: T" ) |
1421 | .Input("indices: Tidx" ) |
1422 | .Input("segment_ids: Tsegmentids" ) |
1423 | .Output("output: T" ) |
1424 | .Attr("T: {bfloat16, half, float, double}" ) |
1425 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1426 | .Attr("Tsegmentids: {int32, int64} = DT_INT32" ) |
1427 | .SetShapeFn(SparseSegmentReductionShapeFn); |
1428 | |
1429 | REGISTER_OP("SparseSegmentSqrtNWithNumSegments" ) |
1430 | .Input("data: T" ) |
1431 | .Input("indices: Tidx" ) |
1432 | .Input("segment_ids: Tsegmentids" ) |
1433 | .Input("num_segments: Tnumsegments" ) |
1434 | .Output("output: T" ) |
1435 | .Attr("T: {bfloat16, half, float, double}" ) |
1436 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1437 | .Attr("Tnumsegments: {int32,int64} = DT_INT32" ) |
1438 | .Attr("Tsegmentids: {int32, int64} = DT_INT32" ) |
1439 | .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn); |
1440 | |
1441 | REGISTER_OP("SparseSegmentSqrtNGrad" ) |
1442 | .Input("grad: T" ) |
1443 | .Input("indices: Tidx" ) |
1444 | .Input("segment_ids: Tsegmentids" ) |
1445 | .Input("output_dim0: int32" ) |
1446 | .Output("output: T" ) |
1447 | .Attr("T: {bfloat16, half, float, double}" ) |
1448 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1449 | .Attr("Tsegmentids: {int32, int64} = DT_INT32" ) |
1450 | .SetShapeFn(SparseSegmentReductionGradShapeFn); |
1451 | |
1452 | REGISTER_OP("All" ) |
1453 | .Input("input: bool" ) |
1454 | .Input("reduction_indices: Tidx" ) |
1455 | .Output("output: bool" ) |
1456 | .Attr("keep_dims: bool = false" ) |
1457 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1458 | .SetShapeFn(shape_inference::ReductionShape); |
1459 | |
1460 | REGISTER_OP("Any" ) |
1461 | .Input("input: bool" ) |
1462 | .Input("reduction_indices: Tidx" ) |
1463 | .Attr("keep_dims: bool = false" ) |
1464 | .Output("output: bool" ) |
1465 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1466 | .SetShapeFn(shape_inference::ReductionShape); |
1467 | |
1468 | // -------------------------------------------------------------------------- |
1469 | |
1470 | namespace { |
1471 | |
// Computes the static length of a Range(start, limit, delta) sequence from
// constant scalar inputs and sets the op's output to a vector of that
// length. T is one of the numeric types admitted by Range's Tidx attr.
template <typename T>
Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
                 const Tensor* delta_t, InferenceContext* const c) {
  T start = start_t->scalar<T>()();
  T limit = limit_t->scalar<T>()();
  T delta = delta_t->scalar<T>()();
  // Reject an inconsistent direction or a zero step.
  if (start > limit && delta > T(0)) {
    return errors::InvalidArgument(
        "Requires start <= limit when delta > 0: " , start, "/" , limit);
  }
  if (start < limit && delta < T(0)) {
    return errors::InvalidArgument(
        "Requires start >= limit when delta < 0: " , start, "/" , limit);
  }
  if (delta == T(0)) {
    return errors::InvalidArgument("Requires delta != 0" );
  }

  int64_t size;
  if (std::is_integral<T>::value) {
    // Integral types: exact ceil(|limit - start| / |delta|) via divup.
    size = Eigen::divup(static_cast<int64_t>(Eigen::numext::abs(limit - start)),
                        static_cast<int64_t>(Eigen::numext::abs(delta)));
  } else {
    // Floating-point types: compute in T, then guard the int64 cast.
    // NOTE(review): NaN inputs would pass the checks above and reach the
    // cast below — presumably excluded by upstream validation; confirm.
    auto size_auto =
        Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta));
    if (size_auto > std::numeric_limits<int64_t>::max()) {
      return errors::InvalidArgument("Requires ((limit - start) / delta) <= " ,
                                     std::numeric_limits<int64_t>::max());
    }
    size = static_cast<int64_t>(size_auto);
  }

  c->set_output(0, c->Vector(static_cast<int64_t>(size)));
  return OkStatus();
}
1507 | |
1508 | } // namespace |
1509 | |
1510 | REGISTER_OP("Range" ) |
1511 | .Input("start: Tidx" ) |
1512 | .Input("limit: Tidx" ) |
1513 | .Input("delta: Tidx" ) |
1514 | .Output("output: Tidx" ) |
1515 | .Attr( |
1516 | "Tidx: " |
1517 | "{bfloat16, half, float, double, int8, int16, int32, int64, uint16, " |
1518 | "uint32} = " |
1519 | "DT_INT32" ) |
1520 | .SetShapeFn([](InferenceContext* c) { |
1521 | ShapeHandle unused; |
1522 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused), |
1523 | " for 'start'" ); |
1524 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused), |
1525 | " for 'limit'" ); |
1526 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused), |
1527 | " for 'delta'" ); |
1528 | const Tensor* start_t = c->input_tensor(0); |
1529 | const Tensor* limit_t = c->input_tensor(1); |
1530 | const Tensor* delta_t = c->input_tensor(2); |
1531 | DataType dtype; |
1532 | TF_RETURN_IF_ERROR(c->GetAttr("Tidx" , &dtype)); |
1533 | if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) { |
1534 | c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
1535 | return OkStatus(); |
1536 | } |
1537 | if (dtype == DT_INT32) { |
1538 | return RangeSize<int32>(start_t, limit_t, delta_t, c); |
1539 | } else if (dtype == DT_INT16) { |
1540 | return RangeSize<int16>(start_t, limit_t, delta_t, c); |
1541 | } else if (dtype == DT_INT8) { |
1542 | return RangeSize<int8>(start_t, limit_t, delta_t, c); |
1543 | } else if (dtype == DT_INT64) { |
1544 | return RangeSize<int64_t>(start_t, limit_t, delta_t, c); |
1545 | } else if (dtype == DT_UINT16) { |
1546 | return RangeSize<uint16>(start_t, limit_t, delta_t, c); |
1547 | } else if (dtype == DT_UINT32) { |
1548 | return RangeSize<uint32>(start_t, limit_t, delta_t, c); |
1549 | } else if (dtype == DT_FLOAT) { |
1550 | return RangeSize<float>(start_t, limit_t, delta_t, c); |
1551 | } else if (dtype == DT_DOUBLE) { |
1552 | return RangeSize<double>(start_t, limit_t, delta_t, c); |
1553 | } else if (dtype == DT_BFLOAT16) { |
1554 | return RangeSize<bfloat16>(start_t, limit_t, delta_t, c); |
1555 | } else { |
1556 | return errors::InvalidArgument("Unsupported dtype" , dtype); |
1557 | } |
1558 | return OkStatus(); |
1559 | }); |
1560 | |
1561 | REGISTER_OP("LinSpace" ) |
1562 | .Input("start: T" ) |
1563 | .Input("stop: T" ) |
1564 | .Input("num: Tidx" ) |
1565 | .Output("output: T" ) |
1566 | .Attr("T: {bfloat16, half, float, double}" ) |
1567 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1568 | .SetShapeFn([](InferenceContext* c) { |
1569 | ShapeHandle unused; |
1570 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused), |
1571 | " for 'start'" ); |
1572 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused), |
1573 | " for 'stop'" ); |
1574 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused), |
1575 | " for 'num'" ); |
1576 | const Tensor* num_t = c->input_tensor(2); |
1577 | if (num_t == nullptr) { |
1578 | c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
1579 | return OkStatus(); |
1580 | } |
1581 | |
1582 | int64_t num; |
1583 | if (num_t->dtype() == DT_INT32) { |
1584 | num = num_t->scalar<int32>()(); |
1585 | } else { |
1586 | num = num_t->scalar<int64_t>()(); |
1587 | } |
1588 | if (num <= 0) return errors::InvalidArgument("Requires num > 0: " , num); |
1589 | c->set_output(0, c->Vector(num)); |
1590 | return OkStatus(); |
1591 | }); |
1592 | |
1593 | REGISTER_OP("Complex" ) |
1594 | .Input("real: T" ) |
1595 | .Input("imag: T" ) |
1596 | .Output("out: Tout" ) |
1597 | .Attr("T: {float, double} = DT_FLOAT" ) |
1598 | .Attr("Tout: {complex64, complex128} = DT_COMPLEX64" ) |
1599 | .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); |
1600 | |
1601 | REGISTER_OP("Real" ) |
1602 | .Input("input: T" ) |
1603 | .Output("output: Tout" ) |
1604 | .Attr("T: {complex64, complex128} = DT_COMPLEX64" ) |
1605 | .Attr("Tout: {float, double} = DT_FLOAT" ) |
1606 | .SetShapeFn(shape_inference::UnchangedShape); |
1607 | |
1608 | REGISTER_OP("Imag" ) |
1609 | .Input("input: T" ) |
1610 | .Output("output: Tout" ) |
1611 | .Attr("T: {complex64, complex128} = DT_COMPLEX64" ) |
1612 | .Attr("Tout: {float, double} = DT_FLOAT" ) |
1613 | .SetShapeFn(shape_inference::UnchangedShape); |
1614 | |
1615 | REGISTER_OP("Angle" ) |
1616 | .Input("input: T" ) |
1617 | .Output("output: Tout" ) |
1618 | .Attr("T: {complex64, complex128} = DT_COMPLEX64" ) |
1619 | .Attr("Tout: {float, double} = DT_FLOAT" ) |
1620 | .SetShapeFn(shape_inference::UnchangedShape); |
1621 | |
1622 | REGISTER_OP("Conj" ) |
1623 | .Input("input: T" ) |
1624 | .Output("output: T" ) |
1625 | .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64" ) |
1626 | .SetShapeFn([](InferenceContext* c) { |
1627 | c->set_output(0, c->input(0)); |
1628 | auto* handle_data = c->input_handle_shapes_and_types(0); |
1629 | if (handle_data != nullptr) { |
1630 | c->set_output_handle_shapes_and_types(0, *handle_data); |
1631 | } |
1632 | return OkStatus(); |
1633 | }); |
1634 | |
1635 | // -------------------------------------------------------------------------- |
1636 | |
1637 | REGISTER_OP("Cross" ) |
1638 | .Input("a: T" ) |
1639 | .Input("b: T" ) |
1640 | .Output("product: T" ) |
1641 | .Attr("T: realnumbertype" ) |
1642 | .SetShapeFn([](InferenceContext* c) { |
1643 | ShapeHandle a_shape; |
1644 | ShapeHandle b_shape; |
1645 | // * Input rank >= 1. |
1646 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape)); |
1647 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape)); |
1648 | |
1649 | // * Both inputs have the same shape. |
1650 | TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape)); |
1651 | |
1652 | // * input_shape[-1] == 3. |
1653 | if (c->RankKnown(a_shape)) { |
1654 | int rank = c->Rank(a_shape); |
1655 | auto dim = c->Dim(a_shape, rank - 1); |
1656 | TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim)); |
1657 | } |
1658 | c->set_output(0, a_shape); |
1659 | return OkStatus(); |
1660 | }); |
1661 | |
1662 | // -------------------------------------------------------------------------- |
1663 | |
1664 | REGISTER_OP("HistogramFixedWidth" ) |
1665 | .Input("values: T" ) |
1666 | .Input("value_range: T" ) |
1667 | .Input("nbins: int32" ) |
1668 | .Output("out: dtype" ) |
1669 | .Attr("T: {int32, int64, float32, float64}" ) |
1670 | .Attr("dtype: {int32, int64} = DT_INT32" ) |
1671 | .SetShapeFn([](InferenceContext* c) { |
1672 | // value_range should be a vector. |
1673 | ShapeHandle value_range_shape; |
1674 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape)); |
1675 | // value_range should have two elements. |
1676 | DimensionHandle unused; |
1677 | TF_RETURN_IF_ERROR( |
1678 | c->WithValue(c->Dim(value_range_shape, 0), 2, &unused)); |
1679 | // nbins should be a scalar. |
1680 | ShapeHandle nbins_shape; |
1681 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape)); |
1682 | |
1683 | // If nbins is available, set the shape from nbins. |
1684 | const Tensor* nbins_input = c->input_tensor(2); |
1685 | if (nbins_input != nullptr) { |
1686 | int64_t nbins; |
1687 | TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins)); |
1688 | // nbins has to be positive. |
1689 | if (nbins <= 0) { |
1690 | return errors::InvalidArgument("Requires nbins > 0: " , nbins); |
1691 | } |
1692 | c->set_output(0, c->Vector(nbins)); |
1693 | } else { |
1694 | c->set_output(0, c->UnknownShapeOfRank(1)); |
1695 | } |
1696 | return OkStatus(); |
1697 | }); |
1698 | |
1699 | REGISTER_OP("Bincount" ) |
1700 | .Input("arr: int32" ) |
1701 | .Input("size: int32" ) |
1702 | .Input("weights: T" ) |
1703 | .Attr("T: {int32, int64, float32, float64}" ) |
1704 | .Output("bins: T" ) |
1705 | .SetShapeFn([](InferenceContext* c) { |
1706 | ShapeHandle unused; |
1707 | // The input `size` must be a scalar. |
1708 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1709 | |
1710 | const Tensor* size_tensor = c->input_tensor(1); |
1711 | if (size_tensor == nullptr) { |
1712 | // Return unknown shape if size is not known. |
1713 | c->set_output(0, c->UnknownShapeOfRank(1)); |
1714 | return OkStatus(); |
1715 | } |
1716 | |
1717 | if (size_tensor->dims() != 0) { |
1718 | return errors::InvalidArgument("Shape must be rank 0 but is rank " , |
1719 | size_tensor->dims()); |
1720 | } |
1721 | |
1722 | // Return `[size]` shape if size is known. |
1723 | int32_t size_val = size_tensor->scalar<int32>()(); |
1724 | if (size_val < 0) { |
1725 | return errors::InvalidArgument("size (" , size_val, |
1726 | ") must be non-negative" ); |
1727 | } |
1728 | c->set_output(0, c->MakeShape({size_val})); |
1729 | return OkStatus(); |
1730 | }); |
1731 | |
1732 | REGISTER_OP("DenseBincount" ) |
1733 | .Input("input: Tidx" ) |
1734 | .Input("size: Tidx" ) |
1735 | .Input("weights: T" ) |
1736 | .Attr("Tidx: {int32, int64}" ) |
1737 | .Attr("T: {int32, int64, float32, float64}" ) |
1738 | .Attr("binary_output: bool = false" ) |
1739 | .Output("output: T" ) |
1740 | .SetShapeFn([](InferenceContext* c) { |
1741 | ShapeHandle unused; |
1742 | // The input `input` must be at most matrix. |
1743 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 2, &unused)); |
1744 | // The input `size` must be a scalar. |
1745 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1746 | |
1747 | const Tensor* size_tensor = c->input_tensor(1); |
1748 | if (size_tensor == nullptr) { |
1749 | // Return unknown shape if size is not known. |
1750 | c->set_output(0, c->UnknownShape()); |
1751 | return OkStatus(); |
1752 | } |
1753 | if (size_tensor->dims() != 0) { |
1754 | return errors::InvalidArgument("Shape must be rank 0 but is rank " , |
1755 | size_tensor->dims()); |
1756 | } |
1757 | |
1758 | int64_t size_val; |
1759 | DataType dtype; |
1760 | TF_RETURN_IF_ERROR(c->GetAttr("Tidx" , &dtype)); |
1761 | if (dtype == DT_INT32) { |
1762 | size_val = static_cast<int64_t>(size_tensor->scalar<int32>()()); |
1763 | } else if (dtype == DT_INT64) { |
1764 | size_val = size_tensor->scalar<int64_t>()(); |
1765 | } else { |
1766 | return errors::InvalidArgument("size dtype must be int32 or int64" ); |
1767 | } |
1768 | // Return `[size]` shape if size is known. |
1769 | if (size_val < 0) { |
1770 | return errors::InvalidArgument("size (" , size_val, |
1771 | ") must be non-negative" ); |
1772 | } |
1773 | if (c->Rank(c->input(0)) == 1) { |
1774 | c->set_output(0, c->MakeShape({size_val})); |
1775 | } else if (c->Rank(c->input(0)) == 2) { |
1776 | c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0), size_val})); |
1777 | } |
1778 | return OkStatus(); |
1779 | }); |
1780 | |
1781 | REGISTER_OP("SparseBincount" ) |
1782 | .Input("indices: int64" ) |
1783 | .Input("values: Tidx" ) |
1784 | .Input("dense_shape: int64" ) |
1785 | .Input("size: Tidx" ) |
1786 | .Input("weights: T" ) |
1787 | .Attr("Tidx: {int32, int64}" ) |
1788 | .Attr("T: {int32, int64, float32, float64}" ) |
1789 | .Attr("binary_output: bool = false" ) |
1790 | .Output("output: T" ) |
1791 | .SetShapeFn([](InferenceContext* c) { |
1792 | const Tensor* size_tensor = c->input_tensor(3); |
1793 | if (size_tensor == nullptr) { |
1794 | // Return unknown shape if size is not known. |
1795 | c->set_output(0, c->UnknownShape()); |
1796 | return OkStatus(); |
1797 | } |
1798 | if (size_tensor->dims() != 0) { |
1799 | return errors::InvalidArgument("Shape must be rank 0 but is rank " , |
1800 | size_tensor->dims()); |
1801 | } |
1802 | |
1803 | int64_t size_val; |
1804 | DataType dtype; |
1805 | TF_RETURN_IF_ERROR(c->GetAttr("Tidx" , &dtype)); |
1806 | if (dtype == DT_INT32) { |
1807 | size_val = static_cast<int64_t>(size_tensor->scalar<int32>()()); |
1808 | } else if (dtype == DT_INT64) { |
1809 | size_val = size_tensor->scalar<int64_t>()(); |
1810 | } else { |
1811 | return errors::InvalidArgument("size dtype must be int32 or int64" ); |
1812 | } |
1813 | // Return `[size]` shape if size is known. |
1814 | if (size_val < 0) { |
1815 | return errors::InvalidArgument("size (" , size_val, |
1816 | ") must be non-negative" ); |
1817 | } |
1818 | |
1819 | const Tensor* shape_tensor = c->input_tensor(2); |
1820 | if (shape_tensor == nullptr) { |
1821 | // Return unknown shape if size is not known. |
1822 | c->set_output(0, c->UnknownShape()); |
1823 | return OkStatus(); |
1824 | } |
1825 | if (shape_tensor->NumElements() == 1) { |
1826 | c->set_output(0, c->MakeShape({size_val})); |
1827 | } else if (shape_tensor->NumElements() == 2) { |
1828 | c->set_output( |
1829 | 0, c->MakeShape({shape_tensor->flat<int64_t>()(0), size_val})); |
1830 | } else { |
1831 | return errors::InvalidArgument("Input must be less than rank 2" ); |
1832 | } |
1833 | return OkStatus(); |
1834 | }); |
1835 | |
1836 | REGISTER_OP("RaggedBincount" ) |
1837 | .Input("splits: int64" ) |
1838 | .Input("values: Tidx" ) |
1839 | .Input("size: Tidx" ) |
1840 | .Input("weights: T" ) |
1841 | .Attr("Tidx: {int32, int64}" ) |
1842 | .Attr("T: {int32, int64, float32, float64}" ) |
1843 | .Attr("binary_output: bool = false" ) |
1844 | .Output("output: T" ) |
1845 | .SetShapeFn([](InferenceContext* c) { |
1846 | c->set_output(0, c->UnknownShape()); |
1847 | return OkStatus(); |
1848 | }); |
1849 | |
// Cumulative sum along `axis`; `exclusive` and `reverse` select the scan
// variant. The output always has the same shape as `x` — `axis` only picks
// the scan dimension — hence UnchangedShape.
REGISTER_OP("Cumsum")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);
1859 | |
// Cumulative product along `axis`; same attr/shape contract as Cumsum:
// output shape equals the shape of `x`.
REGISTER_OP("Cumprod")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);
1869 | |
1870 | REGISTER_OP("CumulativeLogsumexp" ) |
1871 | .Input("x : T" ) |
1872 | .Input("axis: Tidx" ) |
1873 | .Attr("exclusive: bool = false" ) |
1874 | .Attr("reverse: bool = false" ) |
1875 | .Output("out: T" ) |
1876 | .Attr("T: {float16, float32, float64}" ) |
1877 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
1878 | .SetShapeFn(shape_inference::UnchangedShape); |
1879 | |
1880 | REGISTER_OP("QuantizedMatMul" ) |
1881 | .Input("a: T1" ) |
1882 | .Input("b: T2" ) |
1883 | .Input("min_a: float" ) |
1884 | .Input("max_a: float" ) |
1885 | .Input("min_b: float" ) |
1886 | .Input("max_b: float" ) |
1887 | .Output("out: Toutput" ) |
1888 | .Output("min_out: float" ) |
1889 | .Output("max_out: float" ) |
1890 | .Attr("T1: quantizedtype" ) |
1891 | .Attr("T2: quantizedtype" ) |
1892 | .Attr("Toutput: quantizedtype = DT_QINT32" ) |
1893 | .Attr("transpose_a: bool = false" ) |
1894 | .Attr("transpose_b: bool = false" ) |
1895 | .Attr("Tactivation: quantizedtype = DT_QUINT8" ) |
1896 | .SetShapeFn([](InferenceContext* c) { |
1897 | TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c)); |
1898 | ShapeHandle unused; |
1899 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1900 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1901 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
1902 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
1903 | |
1904 | c->set_output(1, c->Scalar()); |
1905 | c->set_output(2, c->Scalar()); |
1906 | return OkStatus(); |
1907 | }); |
1908 | |
1909 | // Note: This op is not commutative w.r.t. to all its inputs. |
1910 | REGISTER_OP("QuantizedMul" ) |
1911 | .Input("x: T1" ) |
1912 | .Input("y: T2" ) |
1913 | .Input("min_x: float" ) |
1914 | .Input("max_x: float" ) |
1915 | .Input("min_y: float" ) |
1916 | .Input("max_y: float" ) |
1917 | .Output("z: Toutput" ) |
1918 | .Output("min_z: float" ) |
1919 | .Output("max_z: float" ) |
1920 | .Attr("T1: quantizedtype" ) |
1921 | .Attr("T2: quantizedtype" ) |
1922 | .Attr("Toutput: quantizedtype = DT_QINT32" ) |
1923 | .SetShapeFn([](InferenceContext* c) { |
1924 | TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); |
1925 | c->set_output(1, c->Scalar()); |
1926 | c->set_output(2, c->Scalar()); |
1927 | return OkStatus(); |
1928 | }); |
1929 | |
1930 | // Note: This op is not commutative w.r.t. to all its inputs. |
1931 | REGISTER_OP("QuantizedAdd" ) |
1932 | .Input("x: T1" ) |
1933 | .Input("y: T2" ) |
1934 | .Input("min_x: float" ) |
1935 | .Input("max_x: float" ) |
1936 | .Input("min_y: float" ) |
1937 | .Input("max_y: float" ) |
1938 | .Output("z: Toutput" ) |
1939 | .Output("min_z: float" ) |
1940 | .Output("max_z: float" ) |
1941 | .Attr("T1: quantizedtype" ) |
1942 | .Attr("T2: quantizedtype" ) |
1943 | .Attr("Toutput: quantizedtype = DT_QINT32" ) |
1944 | .SetShapeFn([](InferenceContext* c) { |
1945 | TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c)); |
1946 | // min_x, max_x, min_y, max_y should be scalar. |
1947 | ShapeHandle unused; |
1948 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1949 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1950 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
1951 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); |
1952 | |
1953 | c->set_output(1, c->Scalar()); |
1954 | c->set_output(2, c->Scalar()); |
1955 | return OkStatus(); |
1956 | }); |
1957 | |
1958 | REGISTER_OP("QuantizeDownAndShrinkRange" ) |
1959 | .Input("input: Tinput" ) |
1960 | .Input("input_min: float" ) |
1961 | .Input("input_max: float" ) |
1962 | .Output("output: out_type" ) |
1963 | .Output("output_min: float" ) |
1964 | .Output("output_max: float" ) |
1965 | .Attr("Tinput: quantizedtype" ) |
1966 | .Attr("out_type: quantizedtype" ) |
1967 | .SetShapeFn([](InferenceContext* c) { |
1968 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
1969 | ShapeHandle unused; |
1970 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1971 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1972 | c->set_output(1, c->Scalar()); |
1973 | c->set_output(2, c->Scalar()); |
1974 | return OkStatus(); |
1975 | }); |
1976 | |
1977 | REGISTER_OP("Requantize" ) |
1978 | .Input("input: Tinput" ) |
1979 | .Input("input_min: float" ) |
1980 | .Input("input_max: float" ) |
1981 | .Input("requested_output_min: float" ) |
1982 | .Input("requested_output_max: float" ) |
1983 | .Output("output: out_type" ) |
1984 | .Output("output_min: float" ) |
1985 | .Output("output_max: float" ) |
1986 | .Attr("Tinput: quantizedtype" ) |
1987 | .Attr("out_type: quantizedtype" ) |
1988 | .SetShapeFn([](InferenceContext* c) { |
1989 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
1990 | ShapeHandle unused; |
1991 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
1992 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
1993 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
1994 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
1995 | c->set_output(1, c->Scalar()); |
1996 | c->set_output(2, c->Scalar()); |
1997 | return OkStatus(); |
1998 | }); |
1999 | |
2000 | REGISTER_OP("RequantizationRange" ) |
2001 | .Input("input: Tinput" ) |
2002 | .Input("input_min: float" ) |
2003 | .Input("input_max: float" ) |
2004 | .Output("output_min: float" ) |
2005 | .Output("output_max: float" ) |
2006 | .Attr("Tinput: quantizedtype" ) |
2007 | .SetShapeFn([](InferenceContext* c) { |
2008 | ShapeHandle unused; |
2009 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
2010 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
2011 | c->set_output(0, c->Scalar()); |
2012 | c->set_output(1, c->Scalar()); |
2013 | return OkStatus(); |
2014 | }); |
2015 | |
2016 | // -------------------------------------------------------------------------- |
2017 | |
// Maps each element of `input` to the index of its bucket within the
// `boundaries` attr. Element-wise, so the output shape equals the input
// shape (UnchangedShape).
REGISTER_OP("Bucketize")
    .Input("input: T")
    .Output("output: int32")
    .Attr("T: {int32, int64, float, double}")
    .Attr("boundaries: list(float)")
    .SetShapeFn(shape_inference::UnchangedShape);
2024 | |
// Clamps `t` between `clip_value_min` and `clip_value_max`. Element-wise,
// so the output shape equals the shape of `t`.
REGISTER_OP("ClipByValue")
    .Input("t: T")
    .Input("clip_value_min: T")
    .Input("clip_value_max: T")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnchangedShape);
2032 | |
2033 | #ifdef INTEL_MKL |
2034 | // Note: This op is not commutative w.r.t. to all its inputs. |
2035 | REGISTER_OP("_MklAddN" ) |
2036 | .Input("inputs: N * T" ) |
2037 | .Input("mkl_input: N * uint8" ) |
2038 | .Output("sum: T" ) |
2039 | .Output("mkl_sum: uint8" ) |
2040 | .Attr("N: int >= 1" ) |
2041 | .Attr("T: numbertype" ) |
2042 | .SetShapeFn([](InferenceContext* c) { |
2043 | ShapeHandle cur = c->input(c->num_inputs() - 1); |
2044 | for (int i = c->num_inputs() - 2; i >= 0; --i) { |
2045 | TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur), |
2046 | "From merging shape " , i, |
2047 | " with other shapes." ); |
2048 | } |
2049 | c->set_output(0, cur); |
2050 | return Status::OK(); |
2051 | }) |
2052 | .Doc(R"doc( |
2053 | Add two input tensors element wise using mkl kernel sum. |
2054 | inputs: Must all be the same size and shape. |
2055 | )doc" ); |
2056 | |
2057 | #endif // INTEL_MKL |
2058 | |
2059 | REGISTER_OP("RequantizePerChannel" ) |
2060 | .Input("input: T" ) |
2061 | .Input("input_min: float" ) |
2062 | .Input("input_max: float" ) |
2063 | .Input("requested_output_min: float" ) |
2064 | .Input("requested_output_max: float" ) |
2065 | .Output("output: out_type" ) |
2066 | .Output("output_min: float" ) |
2067 | .Output("output_max: float" ) |
2068 | .Attr("T: quantizedtype = DT_QINT32" ) |
2069 | .Attr("out_type: quantizedtype = DT_QUINT8" ) |
2070 | .SetShapeFn([](InferenceContext* c) { |
2071 | TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c)); |
2072 | ShapeHandle unused; |
2073 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
2074 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
2075 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); |
2076 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); |
2077 | c->set_output(1, c->Scalar()); |
2078 | c->set_output(2, c->Scalar()); |
2079 | return OkStatus(); |
2080 | }); |
2081 | REGISTER_OP("RequantizationRangePerChannel" ) |
2082 | .Input("input: T" ) |
2083 | .Input("input_min: float" ) |
2084 | .Input("input_max: float" ) |
2085 | .Output("output_min: float" ) |
2086 | .Output("output_max: float" ) |
2087 | .Attr("T: quantizedtype = DT_QINT32" ) |
2088 | .Attr("clip_value_max: float" ) |
2089 | .SetShapeFn([](InferenceContext* c) { |
2090 | ShapeHandle unused; |
2091 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
2092 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
2093 | c->set_output(0, c->Scalar()); |
2094 | c->set_output(1, c->Scalar()); |
2095 | return OkStatus(); |
2096 | }); |
2097 | |
// Element-wise nextafter(x1, x2); `x1` and `x2` broadcast together and the
// output takes the broadcast shape.
REGISTER_OP("NextAfter")
    .Attr("T: {float64, float32} = DT_FLOAT")
    .Input("x1: T")
    .Input("x2: T")
    .Output("output: T")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
2104 | |
2105 | REGISTER_OP("SobolSample" ) |
2106 | .Input("dim: int32" ) |
2107 | .Input("num_results: int32" ) |
2108 | .Input("skip: int32" ) |
2109 | .Attr("dtype: {float, double} = DT_FLOAT" ) |
2110 | .Output("samples: dtype" ) |
2111 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
2112 | ShapeHandle unused; |
2113 | |
2114 | // inputs must be scalars |
2115 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
2116 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
2117 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
2118 | |
2119 | const Tensor* dim_t = c->input_tensor(0); |
2120 | const Tensor* num_results_t = c->input_tensor(1); |
2121 | |
2122 | int32_t dim = dim_t == nullptr ? InferenceContext::kUnknownDim |
2123 | : dim_t->scalar<int32>()(); |
2124 | |
2125 | int32_t num_results = num_results_t == nullptr |
2126 | ? InferenceContext::kUnknownDim |
2127 | : num_results_t->scalar<int32>()(); |
2128 | |
2129 | c->set_output(0, c->Matrix(num_results, dim)); |
2130 | return OkStatus(); |
2131 | }); |
2132 | |
2133 | } // namespace tensorflow |
2134 | |