1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/numeric_op.h"
18#include "tensorflow/core/framework/op.h"
19#include "tensorflow/core/framework/shape_inference.h"
20
21namespace tensorflow {
22
23using shape_inference::DimensionHandle;
24using shape_inference::InferenceContext;
25using shape_inference::ShapeHandle;
26
27REGISTER_OP("AddN")
28 .Input("inputs: N * T")
29 .Output("sum: T")
30 .Attr("N: int >= 1")
31 .Attr("T: {numbertype, variant}")
32 .SetIsCommutative()
33 .SetIsAggregate()
34 .SetShapeFn([](InferenceContext* c) {
35 ShapeHandle cur = c->input(c->num_inputs() - 1);
36 for (int i = c->num_inputs() - 2; i >= 0; --i) {
37 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
38 "From merging shape ", i,
39 " with other shapes.");
40 }
41 c->set_output(0, cur);
42
43 DataType dtype;
44 TF_RETURN_IF_ERROR(c->GetAttr("T", &dtype));
45
46 if (dtype != DT_VARIANT) {
47 // Exit early if not DT_VARIANT.
48 return OkStatus();
49 } else {
50 // DT_VARIANT shape handle shape inference. All sizes and dtypes must
51 // be the same; all shapes must be compatible via Merge.
52 std::vector<shape_inference::ShapeAndType> cur_shapes_and_types;
53 auto* shapes_and_types =
54 c->input_handle_shapes_and_types(c->num_inputs() - 1);
55 if (shapes_and_types) {
56 cur_shapes_and_types = *shapes_and_types;
57 }
58
59 for (int i = c->num_inputs() - 2; i >= 0; --i) {
60 auto shapes_and_types_i = c->input_handle_shapes_and_types(i);
61 if (!shapes_and_types && shapes_and_types_i) {
62 // TODO(ebrevdo): Find cases where this happens and fix their shape
63 // inference. If we are calling AddN on variant types, they should
64 // all have consistent shape_and_type info.
65 shapes_and_types = shapes_and_types_i;
66 } else if (shapes_and_types && shapes_and_types_i) {
67 if (shapes_and_types_i->size() != shapes_and_types->size()) {
68 return errors::InvalidArgument(
69 "shapes_and_types[", i,
70 "].size() == ", shapes_and_types_i->size(),
71 " != shapes_and_types[0].size() == ",
72 shapes_and_types->size());
73 }
74 for (int j = 0; j < shapes_and_types->size(); ++j) {
75 if (shapes_and_types->at(j).dtype !=
76 shapes_and_types_i->at(j).dtype) {
77 return errors::InvalidArgument(
78 "shapes_and_types[", i, "][", j, "].dtype() == ",
79 DataTypeString(shapes_and_types_i->at(j).dtype),
80 " != shapes_and_types[0][", j, "].dtype == ",
81 DataTypeString(shapes_and_types->at(j).dtype));
82 }
83 TF_RETURN_WITH_CONTEXT_IF_ERROR(
84 c->Merge(shapes_and_types_i->at(j).shape,
85 cur_shapes_and_types.at(j).shape,
86 &cur_shapes_and_types.at(j).shape),
87 "From merging shapes_and_types[", i, "][", j, "].shape with ",
88 "shapes_and_types[0][", j, "].shape");
89 }
90 }
91 }
92 if (shapes_and_types) {
93 c->set_output_handle_shapes_and_types(0, cur_shapes_and_types);
94 }
95 return OkStatus();
96 }
97 });
98
// --------------------------------------------------------------------------

// Note that the following operator is just a placeholder and has no
// associated kernel. The code in accumulate_n_optimizer.cc replaces
// this placeholder with a graph of operators that do have kernels.
// The Python code that generates instances of this op is currently in
// contrib/framework/python/ops/accumulate_n_v2.py
// NOTE(review): contrib/ has been removed from recent TensorFlow releases;
// this path is likely stale — confirm the current Python entry point.
//
// Unlike AddN, the output shape is not inferred from the inputs: the shape
// function is ExplicitShape, which presumably reads the `shape` attr.
REGISTER_OP("AccumulateNV2")
    .Input("inputs: N * T")
    .Output("sum: T")
    .Attr("N: int >= 1")
    .Attr("T: numbertype")
    .Attr("shape: shape")
    .SetIsCommutative()
    .SetIsAggregate()
    .SetShapeFn(shape_inference::ExplicitShape);
115
// --------------------------------------------------------------------------

// Batched matrix multiplication. `adj_x` / `adj_y` are boolean flags applied
// to the respective operand.
//
// V1: both operands and the output share dtype T; shape via BatchMatMulShape.
REGISTER_OP("BatchMatMul")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulShape);

// V2: same signature as V1 plus int16; uses BatchMatMulV2Shape (presumably
// adds batch-dimension broadcasting — confirm in common_shape_fns).
REGISTER_OP("BatchMatMulV2")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr(
        "T: {bfloat16, half, float, double, int16, int32, int64, complex64, "
        "complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);

// V3: the operands may have different dtypes (Ta, Tb, including 8-bit
// integers) and the output dtype (Tout) is chosen independently. The shape
// function is shared with V2.
REGISTER_OP("BatchMatMulV3")
    .Input("x: Ta")
    .Input("y: Tb")
    .Output("output: Tout")
    .Attr(
        "Ta: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128}")
    .Attr(
        "Tb: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128}")
    .Attr(
        "Tout: {bfloat16, half, float, double, int16, int32, int64, complex64, "
        "complex128}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);
156
// MKL-specific variants of BatchMatMul / BatchMatMulV2. They are registered
// only in INTEL_MKL builds, are restricted to bfloat16/float, and reuse the
// shape functions of their public counterparts.
#ifdef INTEL_MKL
REGISTER_OP("_MklBatchMatMul")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulShape);

REGISTER_OP("_MklBatchMatMulV2")
    .Input("x: T")
    .Input("y: T")
    .Output("output: T")
    .Attr("T: {bfloat16, float}")
    .Attr("adj_x: bool = false")
    .Attr("adj_y: bool = false")
    .SetShapeFn(shape_inference::BatchMatMulV2Shape);
#endif  // INTEL_MKL
176
// --------------------------------------------------------------------------
// Casting Ops
//
// NOTE: Only a small number of types are supported by
// Cast. The exact casting rule is TBD. The current
// implementation uses C++ static cast rules for numeric
// types, which may be changed in the future.
//
// Both ops accept any source/destination dtype at registration time and keep
// the input shape unchanged. The full-type hooks (NoOp constructor,
// KeepExisting forward fn) leave any existing full-type info as is.
REGISTER_OP("Cast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .Attr("Truncate: bool = false")
    .SetTypeConstructor(full_type::NoOp())
    .SetForwardTypeFn(full_type::KeepExisting())
    .SetShapeFn(shape_inference::UnchangedShape);

// Host-memory variant of Cast (see the op doc below).
REGISTER_OP("_HostCast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .Attr("Truncate: bool = false")
    .SetTypeConstructor(full_type::NoOp())
    .SetForwardTypeFn(full_type::KeepExisting())
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
Cast x of type SrcT to y of DstT.

_HostCast requires its input and produces its output in host memory.
)doc");
208
// --------------------------------------------------------------------------

// Abs: element-wise absolute value on real types; output dtype equals input.
REGISTER_OP("Abs")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double, int8, int16, int32, int64}")
    .SetShapeFn(shape_inference::UnchangedShape);

// ComplexAbs: takes a complex input and produces a real output (Tout), so it
// is registered separately from Abs. Defaults: complex64 -> float.
REGISTER_OP("ComplexAbs")
    .Input("x: T")
    .Output("y: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
223
// Declares cwise unary operations signature: 't -> 't
// All UNARY* macros below produce the builder chain that follows
// REGISTER_OP(...). They differ only in the allowed T set, and all use
// UnchangedShape (the output shape equals the input shape).
//
// UNARY: signed integers, floating point, and complex.
#define UNARY()                                                          \
  Input("x: T")                                                          \
      .Output("y: T")                                                    \
      .Attr(                                                             \
          "T: {bfloat16, half, float, double, int8, int16, int32, int64, " \
          "complex64, complex128}")                                      \
      .SetShapeFn(shape_inference::UnchangedShape)

// UNARY_UNSIGNED: as UNARY, plus unsigned integer types.
#define UNARY_UNSIGNED()                                                 \
  Input("x: T")                                                          \
      .Output("y: T")                                                    \
      .Attr(                                                             \
          "T: {bfloat16, half, float, double, int8, int16, int32, int64, " \
          "uint8, uint16, uint32, uint64, complex64, complex128}")       \
      .SetShapeFn(shape_inference::UnchangedShape)

// UNARY_REAL: real floating-point types only.
#define UNARY_REAL()                              \
  Input("x: T")                                   \
      .Output("y: T")                             \
      .Attr("T: {bfloat16, half, float, double}") \
      .SetShapeFn(shape_inference::UnchangedShape)

// UNARY_COMPLEX: floating-point and complex types (no integers).
#define UNARY_COMPLEX()                                                \
  Input("x: T")                                                        \
      .Output("y: T")                                                  \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)

// UNARY_GRADIENT_COMPLEX: gradient signature (y, dy) -> z over floating-point
// and complex types. The output shape is taken from the first input.
#define UNARY_GRADIENT_COMPLEX()                                       \
  Input("y: T")                                                        \
      .Input("dy: T")                                                  \
      .Output("z: T")                                                  \
      .Attr("T: {bfloat16, half, float, double, complex64, complex128}") \
      .SetShapeFn(shape_inference::UnchangedShape)
259
// Element-wise unary ops. The macro chosen determines the supported dtypes.
// NOTE(review): Inv and Reciprocal are registered identically, as are their
// grads; Inv is presumably a legacy alias — confirm before deprecating.
REGISTER_OP("Neg").UNARY();

REGISTER_OP("Inv").UNARY();

REGISTER_OP("InvGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Reciprocal").UNARY();

REGISTER_OP("ReciprocalGrad").UNARY_GRADIENT_COMPLEX();

// Square additionally accepts unsigned integer types.
REGISTER_OP("Square").UNARY_UNSIGNED();

REGISTER_OP("Sqrt").UNARY_COMPLEX();

REGISTER_OP("SqrtGrad").UNARY_GRADIENT_COMPLEX();

REGISTER_OP("Rsqrt").UNARY_COMPLEX();

REGISTER_OP("Round").UNARY();

REGISTER_OP("RsqrtGrad").UNARY_GRADIENT_COMPLEX();

// Exponential / logarithmic / hyperbolic functions: float and complex only.
REGISTER_OP("Exp").UNARY_COMPLEX();

REGISTER_OP("Expm1").UNARY_COMPLEX();

REGISTER_OP("Log").UNARY_COMPLEX();

REGISTER_OP("Log1p").UNARY_COMPLEX();

REGISTER_OP("Sinh").UNARY_COMPLEX();

REGISTER_OP("Cosh").UNARY_COMPLEX();

REGISTER_OP("Tanh").UNARY_COMPLEX();

REGISTER_OP("Asinh").UNARY_COMPLEX();

REGISTER_OP("Acosh").UNARY_COMPLEX();

REGISTER_OP("Atanh").UNARY_COMPLEX();

REGISTER_OP("TanhGrad").UNARY_GRADIENT_COMPLEX();

// Special functions: real floating-point types only.
REGISTER_OP("Lgamma").UNARY_REAL();

REGISTER_OP("Digamma").UNARY_REAL();

REGISTER_OP("Erf").UNARY_REAL();
REGISTER_OP("Erfinv").UNARY_REAL();
REGISTER_OP("Ndtri").UNARY_REAL();
REGISTER_OP("Erfc").UNARY_REAL();

REGISTER_OP("Sigmoid").UNARY_COMPLEX();

REGISTER_OP("SigmoidGrad").UNARY_GRADIENT_COMPLEX();

// Trigonometric functions. Sin/Cos exclude integer types; Tan and the inverse
// functions are registered with the wider UNARY set (including integers).
REGISTER_OP("Sin").UNARY_COMPLEX();

REGISTER_OP("Cos").UNARY_COMPLEX();

REGISTER_OP("Tan").UNARY();

REGISTER_OP("Asin").UNARY();

REGISTER_OP("Acos").UNARY();

REGISTER_OP("Atan").UNARY();
328
// Internal op: applies the unary ops named in `op_names` to `x`, presumably
// in list order (see the kernel to confirm). The shape is unchanged, as it is
// for every composed op.
REGISTER_OP("_UnaryOpsComposition")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {float, half, double}")
    .Attr("op_names: list(string)")
    .SetShapeFn(shape_inference::UnchangedShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. Graph rewrite pass is
expected to create these operators.
)doc");
339
// Undefine every UNARY* helper macro so none leaks into the rest of the file
// (or into other translation units that include this one). Every other helper
// macro in this file is likewise undefined after use.
#undef UNARY
#undef UNARY_UNSIGNED
#undef UNARY_REAL
#undef UNARY_COMPLEX
#undef UNARY_GRADIENT_COMPLEX
343
// Floating-point classification predicates: real float input, bool output
// of the same shape.
REGISTER_OP("IsNan")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsInf")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("IsFinite")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Sign: same dtype set as the UNARY macro, spelled out explicitly.
REGISTER_OP("Sign")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::UnchangedShape);

// Rounding ops: real floating-point types only; the shape is unchanged.
REGISTER_OP("Floor")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Ceil")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Rint")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
387
// Declares cwise binary operations signature: 't, 't -> 't.
//
// Unlike the UNARY* macros, these do NOT set a shape function; every user
// chains its own .SetShapeFn(...) after the macro.

// BINARY_MORE: wide dtype set including 8/16/32/64-bit unsigned integers.
#define BINARY_MORE()                                                          \
  Input("x: T").Input("y: T").Output("z: T").Attr(                             \
      "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, " \
      "uint32, uint64, int64, complex64, complex128}")

// BINARY_FEWER: floating point, int32/int64 and complex only.
#define BINARY_FEWER()                                                 \
  Input("x: T").Input("y: T").Output("z: T").Attr(                     \
      "T: {bfloat16, half, float, double, int32, int64, complex64, " \
      "complex128}")
399
// Add: broadcasting element-wise addition. It is the only arithmetic op here
// that also accepts `string` (presumably concatenation — see the kernel).
REGISTER_OP("Add")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128, string}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// AddV2: like Add, but drops `string`, adds uint16/32/64, and is marked
// aggregate and commutative.
REGISTER_OP("AddV2")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, uint16, uint32, uint64, "
        "int8, int16, int32, int64, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative();
419
// MKL variants of Add/AddV2, registered only in INTEL_MKL builds. Each data
// input/output is paired with an extra uint8 `mkl_*` tensor (presumably MKL
// layout metadata — confirm with the MKL layout pass).
#ifdef INTEL_MKL
REGISTER_OP("_MklAdd")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
        "complex128, string, bfloat16}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns `x` + `y` element-wise.

*NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");

REGISTER_OP("_MklAddV2")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative()
    .Doc(R"doc(
Returns `x` + `y` element-wise.
*NOTE*: `tf.math.add` supports broadcasting. `tf.math.add_n` does not. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html).
)doc");
#endif  // INTEL_MKL
458
// Sub: broadcasting element-wise subtraction over signed/unsigned integers,
// floating point and complex types.
REGISTER_OP("Sub")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, uint8, int8, uint16, int16, int32, "
        "int64, complex64, complex128, uint32, uint64}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// MKL variant of Sub with extra uint8 `mkl_*` tensors.
// NOTE(review): unlike _MklAdd, this registration is not guarded by
// #ifdef INTEL_MKL, so it exists in every build.
REGISTER_OP("_MklSub")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x - y element-wise.

*NOTE*: `Sub` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
480
// Mul: broadcasting element-wise multiplication (commutative).
REGISTER_OP("Mul").BINARY_MORE().SetIsCommutative().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

// MulNoNan: multiplication restricted to floating point and complex types.
REGISTER_OP("MulNoNan")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note: This op is not commutative w.r.t. all its inputs.
// (The `mkl_*` tensors are tied to x/y positionally, so it is not marked
// SetIsCommutative even though Mul is.)
REGISTER_OP("_MklMul")
    .BINARY_MORE()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns x * y element-wise.

*NOTE*: `Mul` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

// Division family; all broadcast. The rounding semantics of each variant live
// in the kernels, not here.
REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("DivNoNan")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, bfloat16, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateDiv")
    .BINARY_MORE()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("RealDiv").BINARY_MORE().SetShapeFn(
    shape_inference::BroadcastBinaryOpShapeFn);
525
// Note SquaredDifference implements conj(x - y)*(x - y).
REGISTER_OP("SquaredDifference")
    .BINARY_FEWER()
    .SetIsCommutative()
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklSquaredDifference")
    .BINARY_FEWER()
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("mkl_z: uint8")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns (x - y)(x - y) element-wise.

*NOTE*: `SquaredDifference` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

// Xlogy / Xlog1py / Xdivy: broadcasting binary ops over half/float/double and
// complex types (no bfloat16, no integers).
REGISTER_OP("Xlogy")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Xlog1py")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Xdivy")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {half, float, double, complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// The binary signature macros are not used past this point.
#undef BINARY_FEWER
#undef BINARY_MORE
569
// Maximum / Minimum: broadcasting element-wise max/min over real types
// (signed and unsigned integers plus floating point; no complex).
REGISTER_OP("Maximum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
        "int32, uint32, int64, uint64}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Note: This op is not commutative w.r.t. all its inputs.
REGISTER_OP("_MklMaximum")
    .Input("x: T")
    .Input("y: T")
    .Input("mkl_x: uint8")
    .Input("mkl_y: uint8")
    .Output("z: T")
    .Output("mkl_z: uint8")
    .Attr("T: {half, float, double, int32, int64, bfloat16}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
    .Doc(R"doc(
Returns the max of x and y (i.e. x > y ? x : y) element-wise.

*NOTE*: `Maximum` supports broadcasting. More about broadcasting
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");

REGISTER_OP("Minimum")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, half, float, double, int8, uint8, int16, uint16, "
        "int32, uint32, int64, uint64}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
604
// Remainder ops. All broadcast; the rounding convention of each is defined
// by its kernel.
// NOTE(review): `float16` and `half` appear to name the same dtype, so the
// list below has a redundant entry. It is left verbatim because this string
// defines the serialized OpDef — check op-compat tests before removing it.
REGISTER_OP("Mod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, float16, half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("FloorMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {int8, int16, int32, int64, uint8, uint16, uint32, uint64, "
        "bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("TruncateMod")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {int32, int64, bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Pow: broadcasting element-wise power over signed integers, floating point
// and complex types.
REGISTER_OP("Pow")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr(
        "T: {bfloat16, float, half, double, int8, int16, int32, int64, "
        "complex64, complex128}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
636
// Binary special functions. All use standard two-operand broadcasting.
// Igammac / Igamma accept bfloat16 and half in addition to float/double;
// IgammaGradA, Zeta and Polygamma are float/double only.
REGISTER_OP("Igammac")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Igamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("IgammaGradA")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Zeta")
    .Input("x: T")
    .Input("q: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Polygamma")
    .Input("a: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

// Atan2: note the input order is (y, x).
REGISTER_OP("Atan2")
    .Input("y: T")
    .Input("x: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
678
679REGISTER_OP("Betainc")
680 .Input("a: T")
681 .Input("b: T")
682 .Input("x: T")
683 .Output("z: T")
684 .Attr("T: {float, double}")
685 .SetShapeFn([](InferenceContext* c) {
686 const int num_inputs = 3;
687 ShapeHandle output = c->UnknownShape();
688 int num_scalars = 0;
689 ShapeHandle some_non_scalar;
690 for (int i = 0; i < num_inputs; ++i) {
691 ShapeHandle in = c->input(i);
692 if (!c->RankKnown(in)) {
693 some_non_scalar = in;
694 // An input with unknown rank could be either a scalar (to be
695 // broadcast) or some other shape.
696 } else if (c->Rank(in) == 0) {
697 // Input is a scalar, it will be broadcast to the output shape.
698 ++num_scalars;
699 } else {
700 TF_RETURN_IF_ERROR(c->Merge(output, in, &output));
701 some_non_scalar = output;
702 }
703 }
704
705 if (num_scalars == num_inputs - 1) {
706 // If all but one input is known to be a scalar, then output is the
707 // remaining input.
708 output = some_non_scalar;
709 } else if (num_scalars == num_inputs) {
710 // If all are scalars, output is scalar; pick the first one arbitrarily.
711 output = c->input(0);
712 }
713
714 c->set_output(0, output);
715 return OkStatus();
716 });
717
// --------------------------------------------------------------------------

// Declares cwise binary comparison operations signature: 't, 't -> bool,
// where 't has a natural total order.
// T is restricted to `realnumbertype` (complex excluded). Output shape uses
// the standard two-operand broadcasting rule.
#define COMPARISON()             \
  Input("x: T")                  \
      .Input("y: T")             \
      .Output("z: bool")         \
      .Attr("T: realnumbertype") \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("Less").COMPARISON();

REGISTER_OP("LessEqual").COMPARISON();

REGISTER_OP("Greater").COMPARISON();

REGISTER_OP("GreaterEqual").COMPARISON();

#undef COMPARISON
738
// --------------------------------------------------------------------------

// Equality comparisons: any dtype (`T: type`), commutative, bool output.
// The shape function passes the `incompatible_shape_error` attr through to
// BroadcastBinaryOpOutputShapeFnHelper. With the default (true), shapes that
// cannot be broadcast are a shape-inference error. When false, the helper
// presumably yields a non-error output shape instead — confirm in
// common_shape_fns.
#define EQUALITY_COMPARISON()                                              \
  Input("x: T")                                                            \
      .Input("y: T")                                                       \
      .Output("z: bool")                                                   \
      .SetIsCommutative()                                                  \
      .Attr("T: type")                                                     \
      .Attr("incompatible_shape_error: bool = true")                       \
      .SetShapeFn([](InferenceContext* c) {                                \
        ShapeHandle x = c->input(0);                                       \
        ShapeHandle y = c->input(1);                                       \
        ShapeHandle output;                                                \
        bool incompatible_shape_error;                                     \
        TF_RETURN_IF_ERROR(c->GetAttr("incompatible_shape_error",          \
                                      &incompatible_shape_error));         \
        TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFnHelper(           \
            c, x, y, incompatible_shape_error, &output));                  \
        c->set_output(0, output);                                          \
        return OkStatus();                                                 \
      })

REGISTER_OP("Equal").EQUALITY_COMPARISON();

REGISTER_OP("NotEqual").EQUALITY_COMPARISON();

#undef EQUALITY_COMPARISON
766
767REGISTER_OP("ApproximateEqual")
768 .Input("x: T")
769 .Input("y: T")
770 .Output("z: bool")
771 .SetIsCommutative()
772 .Attr("T: numbertype")
773 .Attr("tolerance: float = 0.00001")
774 .SetShapeFn([](InferenceContext* c) {
775 // The inputs 'x' and 'y' must have the same shape.
776 ShapeHandle data_x = c->input(0);
777 ShapeHandle data_y = c->input(1);
778 TF_RETURN_IF_ERROR(c->Merge(data_x, data_y, &data_x));
779 return shape_inference::UnchangedShape(c);
780 });
781
// --------------------------------------------------------------------------

// Boolean logic ops. LogicalNot keeps the input shape; the binary ops are
// commutative and broadcast.
REGISTER_OP("LogicalNot")
    .Input("x: bool")
    .Output("y: bool")
    .SetShapeFn(shape_inference::UnchangedShape);

#define BINARY_LOGICAL() \
  Input("x: bool")       \
      .Input("y: bool")  \
      .Output("z: bool") \
      .SetIsCommutative() \
      .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)

REGISTER_OP("LogicalAnd").BINARY_LOGICAL();

REGISTER_OP("LogicalOr").BINARY_LOGICAL();

#undef BINARY_LOGICAL
801
802// --------------------------------------------------------------------------
803
804REGISTER_OP("Select")
805 .Input("condition: bool")
806 .Input("t: T")
807 .Input("e: T")
808 .Output("output: T")
809 .Attr("T: type")
810 .SetShapeFn([](InferenceContext* c) {
811 auto* handle_data_1 = c->input_handle_shapes_and_types(1);
812 auto* handle_data_2 = c->input_handle_shapes_and_types(2);
813 // Merge handle shape and dtype if applicable.
814 if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
815 const auto size = handle_data_1->size();
816 std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
817 if (size != handle_data_2->size()) {
818 return errors::InvalidArgument(
819 "Trying to merge handles pointing to different numbers of "
820 "tensors.");
821 }
822
823 for (int i = 0; i < size; ++i) {
824 const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
825 const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
826 if (s1.dtype != s2.dtype) {
827 // TODO(apassos) resolve this in the manner of b/32476923
828 return errors::InvalidArgument(
829 "Trying to merge handles pointing to different dtypes.");
830 }
831 merged_handle_data[i].dtype = s1.dtype;
832 TF_RETURN_IF_ERROR(
833 c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
834 }
835
836 c->set_output_handle_shapes_and_types(0, merged_handle_data);
837 }
838
839 // The inputs 'then' and 'else' must have the same shape.
840 ShapeHandle data = c->input(1);
841 ShapeHandle other = c->input(2);
842 TF_RETURN_IF_ERROR(c->Merge(data, other, &data));
843
844 // The input 'cond' must either have the same shape as 'then' and
845 // 'else', or be a vector if 'then' and 'else' are at least vectors.
846 ShapeHandle cond = c->input(0);
847
848 if (!c->RankKnown(cond) || !c->RankKnown(data)) {
849 c->set_output(0, data);
850 return OkStatus();
851 }
852
853 // rank of shape and data is known.
854
855 const int32_t cond_rank = c->Rank(cond);
856 const int32_t data_rank = c->Rank(data);
857
858 if (cond_rank == 0) {
859 // The rank of 'cond' is a scalar.
860 // t and e can have any shape.
861 c->set_output(0, data);
862 return OkStatus();
863 }
864
865 if (cond_rank != 1) {
866 // If 'cond' is not a vector, and not a scalar,
867 // then shape must match 'then' and 'else'
868 TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
869 c->set_output(0, data);
870 return OkStatus();
871 }
872
873 if (data_rank == 0) {
874 // if 'then' and 'else' are scalar also the cond must be
875 TF_RETURN_IF_ERROR(c->Merge(data, cond, &data));
876 c->set_output(0, data);
877 return OkStatus();
878 }
879
880 if (cond_rank == 1) {
881 // if the cond is a vector and the 'then' is not a scalar,
882 // the first dimension of 'then' and 'else'
883 TF_RETURN_IF_ERROR(c->Merge(cond, c->Vector(c->Dim(data, 0)), &cond));
884 c->set_output(0, data);
885 return OkStatus();
886 }
887
888 c->set_output(0, data);
889
890 return OkStatus();
891 });
892
893REGISTER_OP("SelectV2")
894 .Input("condition: bool")
895 .Input("t: T")
896 .Input("e: T")
897 .Output("output: T")
898 .Attr("T: type")
899 .SetShapeFn([](InferenceContext* c) {
900 auto* handle_data_1 = c->input_handle_shapes_and_types(1);
901 auto* handle_data_2 = c->input_handle_shapes_and_types(2);
902 // Merge handle shape and dtype if applicable.
903 if (handle_data_1 != nullptr && handle_data_2 != nullptr) {
904 const auto size = handle_data_1->size();
905 std::vector<shape_inference::ShapeAndType> merged_handle_data(size);
906 if (size != handle_data_2->size()) {
907 return errors::InvalidArgument(
908 "Trying to merge handles pointing to different numbers of "
909 "tensors.");
910 }
911
912 for (int i = 0; i < size; ++i) {
913 const shape_inference::ShapeAndType& s1 = (*handle_data_1)[i];
914 const shape_inference::ShapeAndType& s2 = (*handle_data_2)[i];
915 if (s1.dtype != s2.dtype) {
916 // TODO(apassos) resolve this in the manner of b/32476923
917 return errors::InvalidArgument(
918 "Trying to merge handles pointing to different dtypes.");
919 }
920 merged_handle_data[i].dtype = s1.dtype;
921 TF_RETURN_IF_ERROR(
922 c->Merge(s1.shape, s2.shape, &merged_handle_data[i].shape));
923 }
924
925 c->set_output_handle_shapes_and_types(0, merged_handle_data);
926 }
927
928 // The inputs 'cond', 'then', and 'else' must be broadcastable.
929 // TODO (yongtang): Consolidate 3-ary broadcast instead of
930 // multiple 2-ary broadcast.
931 ShapeHandle cond = c->input(0);
932 ShapeHandle then = c->input(1);
933 ShapeHandle else_ = c->input(2);
934 ShapeHandle other;
935 TF_RETURN_IF_ERROR(
936 BroadcastBinaryOpOutputShapeFnHelper(c, then, else_, true, &other));
937 ShapeHandle output;
938 TF_RETURN_IF_ERROR(
939 BroadcastBinaryOpOutputShapeFnHelper(c, cond, other, true, &output));
940 c->set_output(0, output);
941 return OkStatus();
942 });
943
// --------------------------------------------------------------------------

// MatMul: rank-2 matrix product. `transpose_a` / `transpose_b` are boolean
// flags on the respective operand; MatMulShape computes the output shape.
REGISTER_OP("MatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr(
        "T: {bfloat16, half, float, double, int32, int64, complex64, "
        "complex128}")
    .SetShapeFn(shape_inference::MatMulShape);

// MKL variant of MatMul (INTEL_MKL builds only; bfloat16/float).
#ifdef INTEL_MKL
REGISTER_OP("_MklMatMul")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, float}")
    .SetShapeFn(shape_inference::MatMulShape);
#endif  // INTEL_MKL
967
// SparseMatMul: MatMul whose operands may independently be float or bfloat16
// (Ta, Tb) and whose output is always float. `a_is_sparse` / `b_is_sparse`
// are presumably hints to the kernel — they do not affect shape inference,
// which is the plain MatMulShape.
REGISTER_OP("SparseMatMul")
    .Input("a: Ta")
    .Input("b: Tb")
    .Output("product: float")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("a_is_sparse: bool = false")
    .Attr("b_is_sparse: bool = false")
    .Attr("Ta: {float, bfloat16} = DT_FLOAT")
    .Attr("Tb: {float, bfloat16} = DT_FLOAT")
    .SetShapeFn(shape_inference::MatMulShape);
979
980REGISTER_OP("_FusedMatMul")
981 .Input("a: T")
982 .Input("b: T")
983 .Input("args: num_args * T")
984 .Output("product: T")
985 .Attr("transpose_a: bool = false")
986 .Attr("transpose_b: bool = false")
987 .Attr("T: {bfloat16, half, float}")
988 .Attr("num_args: int >= 0")
989 .Attr("fused_ops: list(string) = []")
990 // Attributes for the FusedBatchNorm ----------- //
991 .Attr("epsilon: float = 0.0001")
992 // Attributes for the LeakyRelu ---------------- //
993 .Attr("leakyrelu_alpha: float = 0.2")
994 // --------------------------------------------- //
995 .SetShapeFn(shape_inference::MatMulShape)
996 .Doc(R"doc(
997Performs a MatMul followed by a specified series of operations.
998
999The inputs to the MatMul are specified by `a` and `b`. The series of operations
1000that follows is specified by the `fused_ops` attribute, which is a list of TF op
1001names specified as strings (e.g. "Relu"). They are performed in order, where the
1002(first) input to each op is the output of the preceding op. The first input and
1003the output of each fused_op must be of type T.
1004
1005Currently supported fused_op combinations are: ["BiasAdd"] and ["BiasAdd",A],
1006where A is one of {"Elu","Relu","Relu6"}.
1007
1008* The first input to BiasAdd is the MatMul result, and the additional BiasAdd
1009input is specified by `args`.
1010* If there is an op A specified, the output of the BiasAdd is the input to op A,
1011and op A produces the _FusedConv2D output. Otherwise, the BiasAdd produces the
1012_FusedConv2D output.
1013
1014*NOTE*: Do not invoke this operator directly in Python. Grappler is
1015expected to create these operators.
1016)doc");
1017
1018// --------------------------------------------------------------------------
1019
// For operations where the output is a reduction function along some
// dimensions of the input.
//
// Every op in this group takes `input` plus `reduction_indices` (int32 or
// int64, default int32) and a `keep_dims` attr, and delegates output-shape
// computation to shape_inference::ReductionShape.
REGISTER_OP("Sum")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("EuclideanNorm")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Mean")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Prod")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

// Min and Max constrain T to {realnumbertype, quantizedtype} rather than the
// `numbertype` accepted by the reductions above.
REGISTER_OP("Min")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: {realnumbertype, quantizedtype}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Max")
    .Input("input: T")
    .Input("reduction_indices: Tidx")
    .Output("output: T")
    .Attr("keep_dims: bool = false")
    .Attr("T: {realnumbertype, quantizedtype}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);
1075
1076namespace {
1077
1078Status ArgOpShape(shape_inference::InferenceContext* c) {
1079 ShapeHandle dimension_shape;
1080 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &dimension_shape));
1081
1082 ShapeHandle input_shape = c->input(0);
1083 if (!c->RankKnown(input_shape)) {
1084 return shape_inference::UnknownShape(c);
1085 }
1086
1087 const int32_t input_rank = c->Rank(input_shape);
1088 if (input_rank <= 1) {
1089 // Reducing a scalar/vector must return a scalar.
1090 return shape_inference::ScalarShape(c);
1091 }
1092
1093 const Tensor* dim_t = c->input_tensor(1);
1094 if (dim_t == nullptr) {
1095 // We don't know the value of the dimension, but we
1096 // know the rank of the input, so return the correct
1097 // rank with unknown dimensions.
1098 std::vector<DimensionHandle> dims(input_rank - 1);
1099 for (int i = 0; i < dims.size(); ++i) {
1100 dims[i] = c->UnknownDim();
1101 }
1102
1103 c->set_output(0, c->MakeShape(dims));
1104 return OkStatus();
1105 }
1106
1107 int64_t dimension_val;
1108 if (dim_t->dtype() == DT_INT32) {
1109 dimension_val = dim_t->scalar<int32>()();
1110 } else {
1111 dimension_val = dim_t->scalar<int64_t>()();
1112 }
1113
1114 int64_t axis = dimension_val < 0 ? dimension_val + input_rank : dimension_val;
1115 if (axis < 0 || axis >= input_rank) {
1116 return errors::InvalidArgument(
1117 "Dimension (", dimension_val, ") must be in the range [", -input_rank,
1118 ", ", input_rank, "), where ", input_rank,
1119 " is the number of dimensions in the input.");
1120 }
1121
1122 // Return the input shape without the dimension being reduced.
1123 std::vector<DimensionHandle> dims;
1124 for (int i = 0; i < input_rank; ++i) {
1125 if (axis != i) {
1126 dims.emplace_back(c->Dim(input_shape, i));
1127 }
1128 }
1129 c->set_output(0, c->MakeShape(dims));
1130 return OkStatus();
1131}
1132
1133} // namespace
1134
// Index-of-extremum ops. Both remove the axis given by the scalar `dimension`
// input from `input`'s shape (see ArgOpShape) and emit indices of type
// `output_type` (default int64).
//
// Note the attr sets differ: ArgMax additionally accepts an int16
// `dimension` and int16/uint16 `output_type`; ArgMin accepts only int32/int64
// for both.
REGISTER_OP("ArgMax")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: {numbertype, bool}")
    .Attr("Tidx: {int16, int32, int64} = DT_INT32")
    .Attr("output_type: {int16, uint16, int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);

REGISTER_OP("ArgMin")
    .Input("input: T")
    .Input("dimension: Tidx")
    .Output("output: output_type")
    .Attr("T: {numbertype, bool}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("output_type: {int32, int64} = DT_INT64")
    .SetShapeFn(ArgOpShape);
1152
1153namespace {
1154
1155Status SegmentReductionShapeFn(InferenceContext* c) {
1156 ShapeHandle data_shape;
1157 ShapeHandle segment_ids_shape;
1158 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1159 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &segment_ids_shape));
1160
1161 ShapeHandle subshape;
1162 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1163
1164 ShapeHandle out;
1165 TF_RETURN_IF_ERROR(
1166 c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
1167 c->set_output(0, out);
1168 return OkStatus();
1169}
1170
1171Status SparseSegmentReductionShapeFn(InferenceContext* c) {
1172 ShapeHandle data_shape;
1173 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1174
1175 ShapeHandle indices_shape;
1176 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1177
1178 ShapeHandle segment_ids_shape;
1179 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1180
1181 // indices and segment_ids should merge cleanly.
1182 ShapeHandle unused;
1183 TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1184
1185 ShapeHandle subshape;
1186 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1187
1188 ShapeHandle out;
1189 TF_RETURN_IF_ERROR(
1190 c->Concatenate(c->Vector(InferenceContext::kUnknownDim), subshape, &out));
1191 c->set_output(0, out);
1192 return OkStatus();
1193}
1194
1195Status SparseSegmentReductionGradShapeFn(InferenceContext* c) {
1196 ShapeHandle data_shape;
1197 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1198
1199 ShapeHandle indices_shape;
1200 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1201
1202 // indices and segment_ids should merge cleanly.
1203 ShapeHandle unused;
1204 TF_RETURN_IF_ERROR(c->Merge(c->input(2), indices_shape, &unused));
1205
1206 // output_dim0 should be a scalar
1207 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1208
1209 ShapeHandle subshape;
1210 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1211
1212 const Tensor* dim0 = c->input_tensor(3);
1213 ShapeHandle dim0_shape;
1214 if (dim0 == nullptr) {
1215 // We don't have the value at inference time, so the output
1216 // shape is unknown.
1217 dim0_shape = c->Vector(InferenceContext::kUnknownDim);
1218 } else {
1219 auto dim0_value = dim0->scalar<int32>()();
1220 if (dim0_value < 0) {
1221 return errors::InvalidArgument(
1222 "Cannot specify a negative value for output_dim0");
1223 }
1224 dim0_shape = c->Vector(dim0_value);
1225 }
1226
1227 ShapeHandle out;
1228 TF_RETURN_IF_ERROR(c->Concatenate(dim0_shape, subshape, &out));
1229 c->set_output(0, out);
1230 return OkStatus();
1231}
1232
1233Status SparseSegmentReductionWithNumSegmentsShapeFn(InferenceContext* c) {
1234 ShapeHandle data_shape;
1235 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &data_shape));
1236
1237 ShapeHandle indices_shape;
1238 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &indices_shape));
1239
1240 ShapeHandle segment_ids_shape;
1241 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &segment_ids_shape));
1242
1243 ShapeHandle num_segments_shape;
1244 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &num_segments_shape));
1245
1246 // indices and segment_ids should merge cleanly.
1247 ShapeHandle unused;
1248 TF_RETURN_IF_ERROR(c->Merge(indices_shape, segment_ids_shape, &unused));
1249
1250 ShapeHandle subshape;
1251 TF_RETURN_IF_ERROR(c->Subshape(data_shape, 1, &subshape));
1252
1253 ShapeHandle out;
1254 const Tensor* dim0 = c->input_tensor(3);
1255 if (dim0 == nullptr) {
1256 // We don't have the value at inference time, so the output
1257 // shape is unknown.
1258 TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(InferenceContext::kUnknownDim),
1259 subshape, &out));
1260 } else {
1261 auto dim0_value = dim0->scalar<int32>()();
1262 if (dim0_value < 0) {
1263 return errors::InvalidArgument(
1264 "Cannot specify a negative value for num_segments");
1265 }
1266 TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(dim0_value), subshape, &out));
1267 }
1268 c->set_output(0, out);
1269 return OkStatus();
1270}
1271} // namespace
1272
// Sorted segment reductions over `data`, keyed by the vector `segment_ids`
// (int32 or int64). All share SegmentReductionShapeFn: the output keeps
// data's trailing dimensions, and its leading dimension is unknown at graph
// construction time.
REGISTER_OP("SegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMean")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

// Min/Max restrict T to realnumbertype (Sum/Mean/Prod accept numbertype).
REGISTER_OP("SegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);

REGISTER_OP("SegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(SegmentReductionShapeFn);
1312
// Unsorted segment reductions: like the Segment* ops, but with an explicit
// `num_segments` input (int32 or int64, default int32). Shape inference is
// shared via shape_inference::UnsortedSegmentReductionShapeFn.
REGISTER_OP("UnsortedSegmentSum")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);

// Max/Min restrict T to realnumbertype (Sum/Prod accept numbertype).
REGISTER_OP("UnsortedSegmentMax")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentMin")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);

REGISTER_OP("UnsortedSegmentProd")
    .Input("data: T")
    .Input("segment_ids: Tindices")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32,int64}")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
1352
// Sparse segment reductions. Each reduction (Sum, Mean, SqrtN) is registered
// in three forms:
//   * <Op>                  (data, indices, segment_ids)
//                           -> SparseSegmentReductionShapeFn;
//   * <Op>WithNumSegments   (+ num_segments: Tnumsegments)
//                           -> SparseSegmentReductionWithNumSegmentsShapeFn;
//   * <Op>Grad              (grad, indices, segment_ids, output_dim0: int32)
//                           -> SparseSegmentReductionGradShapeFn.
// Sum's forward forms accept any realnumbertype; every other op here is
// limited to {bfloat16, half, float, double}.
REGISTER_OP("SparseSegmentSum")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSumWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentSumGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);

REGISTER_OP("SparseSegmentMean")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentMeanWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentMeanGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);

REGISTER_OP("SparseSegmentSqrtN")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionShapeFn);

REGISTER_OP("SparseSegmentSqrtNWithNumSegments")
    .Input("data: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("num_segments: Tnumsegments")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tnumsegments: {int32,int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionWithNumSegmentsShapeFn);

REGISTER_OP("SparseSegmentSqrtNGrad")
    .Input("grad: T")
    .Input("indices: Tidx")
    .Input("segment_ids: Tsegmentids")
    .Input("output_dim0: int32")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("Tsegmentids: {int32, int64} = DT_INT32")
    .SetShapeFn(SparseSegmentReductionGradShapeFn);
1451
// Boolean reductions over `reduction_indices` (int32 or int64); both take and
// produce bool tensors and use shape_inference::ReductionShape, like the
// numeric reductions Sum/Mean/etc.
REGISTER_OP("All")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Output("output: bool")
    .Attr("keep_dims: bool = false")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);

REGISTER_OP("Any")
    .Input("input: bool")
    .Input("reduction_indices: Tidx")
    .Attr("keep_dims: bool = false")
    .Output("output: bool")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ReductionShape);
1467
1468// --------------------------------------------------------------------------
1469
1470namespace {
1471
1472template <typename T>
1473Status RangeSize(const Tensor* start_t, const Tensor* limit_t,
1474 const Tensor* delta_t, InferenceContext* const c) {
1475 T start = start_t->scalar<T>()();
1476 T limit = limit_t->scalar<T>()();
1477 T delta = delta_t->scalar<T>()();
1478 if (start > limit && delta > T(0)) {
1479 return errors::InvalidArgument(
1480 "Requires start <= limit when delta > 0: ", start, "/", limit);
1481 }
1482 if (start < limit && delta < T(0)) {
1483 return errors::InvalidArgument(
1484 "Requires start >= limit when delta < 0: ", start, "/", limit);
1485 }
1486 if (delta == T(0)) {
1487 return errors::InvalidArgument("Requires delta != 0");
1488 }
1489
1490 int64_t size;
1491 if (std::is_integral<T>::value) {
1492 size = Eigen::divup(static_cast<int64_t>(Eigen::numext::abs(limit - start)),
1493 static_cast<int64_t>(Eigen::numext::abs(delta)));
1494 } else {
1495 auto size_auto =
1496 Eigen::numext::ceil(Eigen::numext::abs((limit - start) / delta));
1497 if (size_auto > std::numeric_limits<int64_t>::max()) {
1498 return errors::InvalidArgument("Requires ((limit - start) / delta) <= ",
1499 std::numeric_limits<int64_t>::max());
1500 }
1501 size = static_cast<int64_t>(size_auto);
1502 }
1503
1504 c->set_output(0, c->Vector(static_cast<int64_t>(size)));
1505 return OkStatus();
1506}
1507
1508} // namespace
1509
1510REGISTER_OP("Range")
1511 .Input("start: Tidx")
1512 .Input("limit: Tidx")
1513 .Input("delta: Tidx")
1514 .Output("output: Tidx")
1515 .Attr(
1516 "Tidx: "
1517 "{bfloat16, half, float, double, int8, int16, int32, int64, uint16, "
1518 "uint32} = "
1519 "DT_INT32")
1520 .SetShapeFn([](InferenceContext* c) {
1521 ShapeHandle unused;
1522 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1523 " for 'start'");
1524 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1525 " for 'limit'");
1526 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1527 " for 'delta'");
1528 const Tensor* start_t = c->input_tensor(0);
1529 const Tensor* limit_t = c->input_tensor(1);
1530 const Tensor* delta_t = c->input_tensor(2);
1531 DataType dtype;
1532 TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1533 if (start_t == nullptr || limit_t == nullptr || delta_t == nullptr) {
1534 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1535 return OkStatus();
1536 }
1537 if (dtype == DT_INT32) {
1538 return RangeSize<int32>(start_t, limit_t, delta_t, c);
1539 } else if (dtype == DT_INT16) {
1540 return RangeSize<int16>(start_t, limit_t, delta_t, c);
1541 } else if (dtype == DT_INT8) {
1542 return RangeSize<int8>(start_t, limit_t, delta_t, c);
1543 } else if (dtype == DT_INT64) {
1544 return RangeSize<int64_t>(start_t, limit_t, delta_t, c);
1545 } else if (dtype == DT_UINT16) {
1546 return RangeSize<uint16>(start_t, limit_t, delta_t, c);
1547 } else if (dtype == DT_UINT32) {
1548 return RangeSize<uint32>(start_t, limit_t, delta_t, c);
1549 } else if (dtype == DT_FLOAT) {
1550 return RangeSize<float>(start_t, limit_t, delta_t, c);
1551 } else if (dtype == DT_DOUBLE) {
1552 return RangeSize<double>(start_t, limit_t, delta_t, c);
1553 } else if (dtype == DT_BFLOAT16) {
1554 return RangeSize<bfloat16>(start_t, limit_t, delta_t, c);
1555 } else {
1556 return errors::InvalidArgument("Unsupported dtype", dtype);
1557 }
1558 return OkStatus();
1559 });
1560
1561REGISTER_OP("LinSpace")
1562 .Input("start: T")
1563 .Input("stop: T")
1564 .Input("num: Tidx")
1565 .Output("output: T")
1566 .Attr("T: {bfloat16, half, float, double}")
1567 .Attr("Tidx: {int32, int64} = DT_INT32")
1568 .SetShapeFn([](InferenceContext* c) {
1569 ShapeHandle unused;
1570 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
1571 " for 'start'");
1572 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(1), 0, &unused),
1573 " for 'stop'");
1574 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(2), 0, &unused),
1575 " for 'num'");
1576 const Tensor* num_t = c->input_tensor(2);
1577 if (num_t == nullptr) {
1578 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
1579 return OkStatus();
1580 }
1581
1582 int64_t num;
1583 if (num_t->dtype() == DT_INT32) {
1584 num = num_t->scalar<int32>()();
1585 } else {
1586 num = num_t->scalar<int64_t>()();
1587 }
1588 if (num <= 0) return errors::InvalidArgument("Requires num > 0: ", num);
1589 c->set_output(0, c->Vector(num));
1590 return OkStatus();
1591 });
1592
// Conversions between real and complex tensors. Complex combines two real
// tensors of type T into Tout and uses broadcasting shape inference
// (shape_inference::BroadcastBinaryOpShapeFn). Real, Imag and Angle map a
// complex tensor T to a real tensor Tout with the input's shape unchanged.
REGISTER_OP("Complex")
    .Input("real: T")
    .Input("imag: T")
    .Output("out: Tout")
    .Attr("T: {float, double} = DT_FLOAT")
    .Attr("Tout: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);

REGISTER_OP("Real")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Imag")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Angle")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
1621
1622REGISTER_OP("Conj")
1623 .Input("input: T")
1624 .Output("output: T")
1625 .Attr("T: {complex64, complex128, variant} = DT_COMPLEX64")
1626 .SetShapeFn([](InferenceContext* c) {
1627 c->set_output(0, c->input(0));
1628 auto* handle_data = c->input_handle_shapes_and_types(0);
1629 if (handle_data != nullptr) {
1630 c->set_output_handle_shapes_and_types(0, *handle_data);
1631 }
1632 return OkStatus();
1633 });
1634
1635// --------------------------------------------------------------------------
1636
1637REGISTER_OP("Cross")
1638 .Input("a: T")
1639 .Input("b: T")
1640 .Output("product: T")
1641 .Attr("T: realnumbertype")
1642 .SetShapeFn([](InferenceContext* c) {
1643 ShapeHandle a_shape;
1644 ShapeHandle b_shape;
1645 // * Input rank >= 1.
1646 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &a_shape));
1647 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &b_shape));
1648
1649 // * Both inputs have the same shape.
1650 TF_RETURN_IF_ERROR(c->Merge(a_shape, b_shape, &a_shape));
1651
1652 // * input_shape[-1] == 3.
1653 if (c->RankKnown(a_shape)) {
1654 int rank = c->Rank(a_shape);
1655 auto dim = c->Dim(a_shape, rank - 1);
1656 TF_RETURN_IF_ERROR(c->WithValue(dim, 3, &dim));
1657 }
1658 c->set_output(0, a_shape);
1659 return OkStatus();
1660 });
1661
1662// --------------------------------------------------------------------------
1663
1664REGISTER_OP("HistogramFixedWidth")
1665 .Input("values: T")
1666 .Input("value_range: T")
1667 .Input("nbins: int32")
1668 .Output("out: dtype")
1669 .Attr("T: {int32, int64, float32, float64}")
1670 .Attr("dtype: {int32, int64} = DT_INT32")
1671 .SetShapeFn([](InferenceContext* c) {
1672 // value_range should be a vector.
1673 ShapeHandle value_range_shape;
1674 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &value_range_shape));
1675 // value_range should have two elements.
1676 DimensionHandle unused;
1677 TF_RETURN_IF_ERROR(
1678 c->WithValue(c->Dim(value_range_shape, 0), 2, &unused));
1679 // nbins should be a scalar.
1680 ShapeHandle nbins_shape;
1681 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &nbins_shape));
1682
1683 // If nbins is available, set the shape from nbins.
1684 const Tensor* nbins_input = c->input_tensor(2);
1685 if (nbins_input != nullptr) {
1686 int64_t nbins;
1687 TF_RETURN_IF_ERROR(c->GetScalarFromTensor(nbins_input, &nbins));
1688 // nbins has to be positive.
1689 if (nbins <= 0) {
1690 return errors::InvalidArgument("Requires nbins > 0: ", nbins);
1691 }
1692 c->set_output(0, c->Vector(nbins));
1693 } else {
1694 c->set_output(0, c->UnknownShapeOfRank(1));
1695 }
1696 return OkStatus();
1697 });
1698
1699REGISTER_OP("Bincount")
1700 .Input("arr: int32")
1701 .Input("size: int32")
1702 .Input("weights: T")
1703 .Attr("T: {int32, int64, float32, float64}")
1704 .Output("bins: T")
1705 .SetShapeFn([](InferenceContext* c) {
1706 ShapeHandle unused;
1707 // The input `size` must be a scalar.
1708 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1709
1710 const Tensor* size_tensor = c->input_tensor(1);
1711 if (size_tensor == nullptr) {
1712 // Return unknown shape if size is not known.
1713 c->set_output(0, c->UnknownShapeOfRank(1));
1714 return OkStatus();
1715 }
1716
1717 if (size_tensor->dims() != 0) {
1718 return errors::InvalidArgument("Shape must be rank 0 but is rank ",
1719 size_tensor->dims());
1720 }
1721
1722 // Return `[size]` shape if size is known.
1723 int32_t size_val = size_tensor->scalar<int32>()();
1724 if (size_val < 0) {
1725 return errors::InvalidArgument("size (", size_val,
1726 ") must be non-negative");
1727 }
1728 c->set_output(0, c->MakeShape({size_val}));
1729 return OkStatus();
1730 });
1731
1732REGISTER_OP("DenseBincount")
1733 .Input("input: Tidx")
1734 .Input("size: Tidx")
1735 .Input("weights: T")
1736 .Attr("Tidx: {int32, int64}")
1737 .Attr("T: {int32, int64, float32, float64}")
1738 .Attr("binary_output: bool = false")
1739 .Output("output: T")
1740 .SetShapeFn([](InferenceContext* c) {
1741 ShapeHandle unused;
1742 // The input `input` must be at most matrix.
1743 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 2, &unused));
1744 // The input `size` must be a scalar.
1745 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1746
1747 const Tensor* size_tensor = c->input_tensor(1);
1748 if (size_tensor == nullptr) {
1749 // Return unknown shape if size is not known.
1750 c->set_output(0, c->UnknownShape());
1751 return OkStatus();
1752 }
1753 if (size_tensor->dims() != 0) {
1754 return errors::InvalidArgument("Shape must be rank 0 but is rank ",
1755 size_tensor->dims());
1756 }
1757
1758 int64_t size_val;
1759 DataType dtype;
1760 TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1761 if (dtype == DT_INT32) {
1762 size_val = static_cast<int64_t>(size_tensor->scalar<int32>()());
1763 } else if (dtype == DT_INT64) {
1764 size_val = size_tensor->scalar<int64_t>()();
1765 } else {
1766 return errors::InvalidArgument("size dtype must be int32 or int64");
1767 }
1768 // Return `[size]` shape if size is known.
1769 if (size_val < 0) {
1770 return errors::InvalidArgument("size (", size_val,
1771 ") must be non-negative");
1772 }
1773 if (c->Rank(c->input(0)) == 1) {
1774 c->set_output(0, c->MakeShape({size_val}));
1775 } else if (c->Rank(c->input(0)) == 2) {
1776 c->set_output(0, c->MakeShape({c->Dim(c->input(0), 0), size_val}));
1777 }
1778 return OkStatus();
1779 });
1780
1781REGISTER_OP("SparseBincount")
1782 .Input("indices: int64")
1783 .Input("values: Tidx")
1784 .Input("dense_shape: int64")
1785 .Input("size: Tidx")
1786 .Input("weights: T")
1787 .Attr("Tidx: {int32, int64}")
1788 .Attr("T: {int32, int64, float32, float64}")
1789 .Attr("binary_output: bool = false")
1790 .Output("output: T")
1791 .SetShapeFn([](InferenceContext* c) {
1792 const Tensor* size_tensor = c->input_tensor(3);
1793 if (size_tensor == nullptr) {
1794 // Return unknown shape if size is not known.
1795 c->set_output(0, c->UnknownShape());
1796 return OkStatus();
1797 }
1798 if (size_tensor->dims() != 0) {
1799 return errors::InvalidArgument("Shape must be rank 0 but is rank ",
1800 size_tensor->dims());
1801 }
1802
1803 int64_t size_val;
1804 DataType dtype;
1805 TF_RETURN_IF_ERROR(c->GetAttr("Tidx", &dtype));
1806 if (dtype == DT_INT32) {
1807 size_val = static_cast<int64_t>(size_tensor->scalar<int32>()());
1808 } else if (dtype == DT_INT64) {
1809 size_val = size_tensor->scalar<int64_t>()();
1810 } else {
1811 return errors::InvalidArgument("size dtype must be int32 or int64");
1812 }
1813 // Return `[size]` shape if size is known.
1814 if (size_val < 0) {
1815 return errors::InvalidArgument("size (", size_val,
1816 ") must be non-negative");
1817 }
1818
1819 const Tensor* shape_tensor = c->input_tensor(2);
1820 if (shape_tensor == nullptr) {
1821 // Return unknown shape if size is not known.
1822 c->set_output(0, c->UnknownShape());
1823 return OkStatus();
1824 }
1825 if (shape_tensor->NumElements() == 1) {
1826 c->set_output(0, c->MakeShape({size_val}));
1827 } else if (shape_tensor->NumElements() == 2) {
1828 c->set_output(
1829 0, c->MakeShape({shape_tensor->flat<int64_t>()(0), size_val}));
1830 } else {
1831 return errors::InvalidArgument("Input must be less than rank 2");
1832 }
1833 return OkStatus();
1834 });
1835
1836REGISTER_OP("RaggedBincount")
1837 .Input("splits: int64")
1838 .Input("values: Tidx")
1839 .Input("size: Tidx")
1840 .Input("weights: T")
1841 .Attr("Tidx: {int32, int64}")
1842 .Attr("T: {int32, int64, float32, float64}")
1843 .Attr("binary_output: bool = false")
1844 .Output("output: T")
1845 .SetShapeFn([](InferenceContext* c) {
1846 c->set_output(0, c->UnknownShape());
1847 return OkStatus();
1848 });
1849
// Cumulative scans along the scalar `axis` input (int32 or int64). Each takes
// `exclusive` and `reverse` attrs and produces an output with the same shape
// as `x` (shape_inference::UnchangedShape). CumulativeLogsumexp is limited
// to floating-point T.
REGISTER_OP("Cumsum")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("Cumprod")
    .Input("x: T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: numbertype")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);

REGISTER_OP("CumulativeLogsumexp")
    .Input("x : T")
    .Input("axis: Tidx")
    .Attr("exclusive: bool = false")
    .Attr("reverse: bool = false")
    .Output("out: T")
    .Attr("T: {float16, float32, float64}")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnchangedShape);
1879
1880REGISTER_OP("QuantizedMatMul")
1881 .Input("a: T1")
1882 .Input("b: T2")
1883 .Input("min_a: float")
1884 .Input("max_a: float")
1885 .Input("min_b: float")
1886 .Input("max_b: float")
1887 .Output("out: Toutput")
1888 .Output("min_out: float")
1889 .Output("max_out: float")
1890 .Attr("T1: quantizedtype")
1891 .Attr("T2: quantizedtype")
1892 .Attr("Toutput: quantizedtype = DT_QINT32")
1893 .Attr("transpose_a: bool = false")
1894 .Attr("transpose_b: bool = false")
1895 .Attr("Tactivation: quantizedtype = DT_QUINT8")
1896 .SetShapeFn([](InferenceContext* c) {
1897 TF_RETURN_IF_ERROR(shape_inference::MatMulShape(c));
1898 ShapeHandle unused;
1899 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1900 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1901 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1902 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1903
1904 c->set_output(1, c->Scalar());
1905 c->set_output(2, c->Scalar());
1906 return OkStatus();
1907 });
1908
// Note: This op is not commutative w.r.t. all its inputs.
1910REGISTER_OP("QuantizedMul")
1911 .Input("x: T1")
1912 .Input("y: T2")
1913 .Input("min_x: float")
1914 .Input("max_x: float")
1915 .Input("min_y: float")
1916 .Input("max_y: float")
1917 .Output("z: Toutput")
1918 .Output("min_z: float")
1919 .Output("max_z: float")
1920 .Attr("T1: quantizedtype")
1921 .Attr("T2: quantizedtype")
1922 .Attr("Toutput: quantizedtype = DT_QINT32")
1923 .SetShapeFn([](InferenceContext* c) {
1924 TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
1925 c->set_output(1, c->Scalar());
1926 c->set_output(2, c->Scalar());
1927 return OkStatus();
1928 });
1929
// Note: This op is not commutative w.r.t. all its inputs.
1931REGISTER_OP("QuantizedAdd")
1932 .Input("x: T1")
1933 .Input("y: T2")
1934 .Input("min_x: float")
1935 .Input("max_x: float")
1936 .Input("min_y: float")
1937 .Input("max_y: float")
1938 .Output("z: Toutput")
1939 .Output("min_z: float")
1940 .Output("max_z: float")
1941 .Attr("T1: quantizedtype")
1942 .Attr("T2: quantizedtype")
1943 .Attr("Toutput: quantizedtype = DT_QINT32")
1944 .SetShapeFn([](InferenceContext* c) {
1945 TF_RETURN_IF_ERROR(shape_inference::BroadcastBinaryOpShapeFn(c));
1946 // min_x, max_x, min_y, max_y should be scalar.
1947 ShapeHandle unused;
1948 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1949 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1950 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1951 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1952
1953 c->set_output(1, c->Scalar());
1954 c->set_output(2, c->Scalar());
1955 return OkStatus();
1956 });
1957
1958REGISTER_OP("QuantizeDownAndShrinkRange")
1959 .Input("input: Tinput")
1960 .Input("input_min: float")
1961 .Input("input_max: float")
1962 .Output("output: out_type")
1963 .Output("output_min: float")
1964 .Output("output_max: float")
1965 .Attr("Tinput: quantizedtype")
1966 .Attr("out_type: quantizedtype")
1967 .SetShapeFn([](InferenceContext* c) {
1968 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1969 ShapeHandle unused;
1970 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1971 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1972 c->set_output(1, c->Scalar());
1973 c->set_output(2, c->Scalar());
1974 return OkStatus();
1975 });
1976
1977REGISTER_OP("Requantize")
1978 .Input("input: Tinput")
1979 .Input("input_min: float")
1980 .Input("input_max: float")
1981 .Input("requested_output_min: float")
1982 .Input("requested_output_max: float")
1983 .Output("output: out_type")
1984 .Output("output_min: float")
1985 .Output("output_max: float")
1986 .Attr("Tinput: quantizedtype")
1987 .Attr("out_type: quantizedtype")
1988 .SetShapeFn([](InferenceContext* c) {
1989 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1990 ShapeHandle unused;
1991 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1992 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1993 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1994 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1995 c->set_output(1, c->Scalar());
1996 c->set_output(2, c->Scalar());
1997 return OkStatus();
1998 });
1999
2000REGISTER_OP("RequantizationRange")
2001 .Input("input: Tinput")
2002 .Input("input_min: float")
2003 .Input("input_max: float")
2004 .Output("output_min: float")
2005 .Output("output_max: float")
2006 .Attr("Tinput: quantizedtype")
2007 .SetShapeFn([](InferenceContext* c) {
2008 ShapeHandle unused;
2009 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2010 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2011 c->set_output(0, c->Scalar());
2012 c->set_output(1, c->Scalar());
2013 return OkStatus();
2014 });
2015
2016// --------------------------------------------------------------------------
2017
2018REGISTER_OP("Bucketize")
2019 .Input("input: T")
2020 .Output("output: int32")
2021 .Attr("T: {int32, int64, float, double}")
2022 .Attr("boundaries: list(float)")
2023 .SetShapeFn(shape_inference::UnchangedShape);
2024
2025REGISTER_OP("ClipByValue")
2026 .Input("t: T")
2027 .Input("clip_value_min: T")
2028 .Input("clip_value_max: T")
2029 .Output("output: T")
2030 .Attr("T: numbertype")
2031 .SetShapeFn(shape_inference::UnchangedShape);
2032
2033#ifdef INTEL_MKL
2034// Note: This op is not commutative w.r.t. to all its inputs.
2035REGISTER_OP("_MklAddN")
2036 .Input("inputs: N * T")
2037 .Input("mkl_input: N * uint8")
2038 .Output("sum: T")
2039 .Output("mkl_sum: uint8")
2040 .Attr("N: int >= 1")
2041 .Attr("T: numbertype")
2042 .SetShapeFn([](InferenceContext* c) {
2043 ShapeHandle cur = c->input(c->num_inputs() - 1);
2044 for (int i = c->num_inputs() - 2; i >= 0; --i) {
2045 TF_RETURN_WITH_CONTEXT_IF_ERROR(c->Merge(c->input(i), cur, &cur),
2046 "From merging shape ", i,
2047 " with other shapes.");
2048 }
2049 c->set_output(0, cur);
2050 return Status::OK();
2051 })
2052 .Doc(R"doc(
2053Add two input tensors element wise using mkl kernel sum.
2054inputs: Must all be the same size and shape.
2055)doc");
2056
2057#endif // INTEL_MKL
2058
2059REGISTER_OP("RequantizePerChannel")
2060 .Input("input: T")
2061 .Input("input_min: float")
2062 .Input("input_max: float")
2063 .Input("requested_output_min: float")
2064 .Input("requested_output_max: float")
2065 .Output("output: out_type")
2066 .Output("output_min: float")
2067 .Output("output_max: float")
2068 .Attr("T: quantizedtype = DT_QINT32")
2069 .Attr("out_type: quantizedtype = DT_QUINT8")
2070 .SetShapeFn([](InferenceContext* c) {
2071 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
2072 ShapeHandle unused;
2073 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2074 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2075 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2076 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2077 c->set_output(1, c->Scalar());
2078 c->set_output(2, c->Scalar());
2079 return OkStatus();
2080 });
2081REGISTER_OP("RequantizationRangePerChannel")
2082 .Input("input: T")
2083 .Input("input_min: float")
2084 .Input("input_max: float")
2085 .Output("output_min: float")
2086 .Output("output_max: float")
2087 .Attr("T: quantizedtype = DT_QINT32")
2088 .Attr("clip_value_max: float")
2089 .SetShapeFn([](InferenceContext* c) {
2090 ShapeHandle unused;
2091 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
2092 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2093 c->set_output(0, c->Scalar());
2094 c->set_output(1, c->Scalar());
2095 return OkStatus();
2096 });
2097
2098REGISTER_OP("NextAfter")
2099 .Attr("T: {float64, float32} = DT_FLOAT")
2100 .Input("x1: T")
2101 .Input("x2: T")
2102 .Output("output: T")
2103 .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
2104
2105REGISTER_OP("SobolSample")
2106 .Input("dim: int32")
2107 .Input("num_results: int32")
2108 .Input("skip: int32")
2109 .Attr("dtype: {float, double} = DT_FLOAT")
2110 .Output("samples: dtype")
2111 .SetShapeFn([](shape_inference::InferenceContext* c) {
2112 ShapeHandle unused;
2113
2114 // inputs must be scalars
2115 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
2116 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
2117 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2118
2119 const Tensor* dim_t = c->input_tensor(0);
2120 const Tensor* num_results_t = c->input_tensor(1);
2121
2122 int32_t dim = dim_t == nullptr ? InferenceContext::kUnknownDim
2123 : dim_t->scalar<int32>()();
2124
2125 int32_t num_results = num_results_t == nullptr
2126 ? InferenceContext::kUnknownDim
2127 : num_results_t->scalar<int32>()();
2128
2129 c->set_output(0, c->Matrix(num_results, dim));
2130 return OkStatus();
2131 });
2132
2133} // namespace tensorflow
2134