1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/op_kernel.h"
19#include "tensorflow/core/framework/register_types.h"
20#include "tensorflow/core/framework/shape_inference.h"
21
22namespace tensorflow {
23
24namespace {
25Status RiscBinaryNonBroadcastOpShapeFn(shape_inference::InferenceContext* c) {
26 const auto rank = c->Rank(c->input(0));
27 if (rank != c->Rank(c->input(1))) {
28 return errors::InvalidArgument("Mismatch rank for input.");
29 }
30 for (int i = 0; i < rank; ++i) {
31 if (!c->ValueKnown(c->Dim(c->input(0), i)) ||
32 !c->ValueKnown(c->Dim(c->input(1), i))) {
33 continue;
34 }
35 if (c->Value(c->Dim(c->input(0), i)) != c->Value(c->Dim(c->input(1), i))) {
36 return errors::InvalidArgument("Mismatch shapes for input.");
37 }
38 }
39 c->set_output(0, c->input(0));
40 auto* handle_data = c->input_handle_shapes_and_types(0);
41 if (handle_data != nullptr) {
42 c->set_output_handle_shapes_and_types(0, *handle_data);
43 }
44 return OkStatus();
45}
46} // namespace
47
// Unary op: `y` has the same floating-point dtype `T` as `x`. Its shape is
// taken from `x` via shape_inference::UnchangedShape.
// Presumably computes |x| element-wise (inferred from the name only).
REGISTER_OP("RiscAbs")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
53
// Binary op on two operands of the same floating-point dtype `T`.
// RiscBinaryNonBroadcastOpShapeFn requires `x` and `y` to have matching
// shapes (no broadcasting). The op is also flagged aggregate and commutative.
REGISTER_OP("RiscAdd")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn)
    .SetIsAggregate()
    .SetIsCommutative();
62
// TODO(b/178234771): retire this.
// Generic binary arithmetic op whose operation is selected by `op_type`.
// Unlike the dedicated binary ops (RiscAdd, RiscSub, ...), it uses
// UnchangedShape: `z` takes the shape of `x`, and `y`'s shape is not checked.
REGISTER_OP("RiscBinaryArithmetic")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("op_type: {'ADD', 'SUB', 'MUL', 'DIV', 'REM', 'MIN', 'POW'}")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
71
// Comparison of `x` and `y` selected by `op_type`, producing a bool tensor.
// The operands must have matching shapes (RiscBinaryNonBroadcastOpShapeFn).
REGISTER_OP("RiscBinaryComparison")
    .Input("x: T")
    .Input("y: T")
    .Output("z: bool")
    .Attr("op_type: {'EQ', 'NE', 'GE', 'GT', 'LE', 'LT'}")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
79
// TODO(b/178234771): change shape function.
// Maps `x` of dtype SrcT to `y` of dtype DstT; any dtype pair is accepted.
// The output shape is currently reported as unknown. NOTE(review): presumably
// because a bitcast can change the shape when the element sizes differ — confirm.
REGISTER_OP("RiscBitcast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .SetShapeFn(shape_inference::UnknownShape);
87
// TODO(b/178234771): change shape function.
// Takes `input` and an int32/int64 `shape` tensor (default int32). Presumably
// broadcasts `input` to `shape` (inferred from the name). The output shape is
// currently reported as unknown.
REGISTER_OP("RiscBroadcast")
    .Input("input: T")
    .Input("shape: Tidx")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnknownShape);
96
// Converts `x` (dtype SrcT) to `y` (dtype DstT); any dtype pair is accepted.
// The output keeps the shape of `x` (UnchangedShape).
REGISTER_OP("RiscCast")
    .Input("x: SrcT")
    .Output("y: DstT")
    .Attr("SrcT: type")
    .Attr("DstT: type")
    .SetShapeFn(shape_inference::UnchangedShape);
103
// Unary op: `y` has the dtype and shape of `x` (UnchangedShape).
// Presumably element-wise ceiling (inferred from the name only).
REGISTER_OP("RiscCeil")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
109
// TODO(b/178234771): change shape function.
// Single-input matrix op; the output has the same dtype `T` as the input.
// The output shape is currently reported as unknown. NOTE(review): if this is a
// Cholesky factorization, the output shape would equal the input's — confirm
// before tightening.
REGISTER_OP("RiscCholesky")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnknownShape);
116
// Concatenates N >= 2 tensors of dtype T along `axis`. Shape inference reuses
// ConcatV2Shape; that is why `axis` must remain the last input.
REGISTER_OP("RiscConcat")
    .Input("values: N * T")
    .Input("axis: Tidx")
    .Output("output: T")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::ConcatV2Shape);
125
// TODO(b/178234771): change shape function.
// Functional conditional: a bool `pred`, true/false operands of dtype SrcT, and
// branch functions `func_true` / `func_false`. The single output has dtype
// DstT, and its shape is currently reported as unknown.
REGISTER_OP("RiscCondition")
    .Input("pred: bool")
    .Input("input_true: SrcT")
    .Input("input_false: SrcT")
    .Output("output: DstT")
    .Attr("func_true: func")
    .Attr("func_false: func")
    .Attr("SrcT: {bfloat16, half, float, double}")
    .Attr("DstT: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnknownShape);
137
// TODO(b/178234771): change shape function.
// Convolution-style op on `input` and `filter`. It takes `strides`, the standard
// convnet data_format attr, and `dilations` (default four 1s). The output shape
// is currently reported as unknown.
// Note: SetShapeFn appears before the last Attr call. Builder order between
// SetShapeFn and Attr is irrelevant; only Input/Output/Attr order shapes the OpDef.
REGISTER_OP("RiscConv")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetConvnetDataFormatAttrString())
    .SetShapeFn(shape_inference::UnknownShape)
    .Attr("dilations: list(int) = [1, 1, 1, 1]");
148
// Unary op: `y` has the dtype and shape of `x` (UnchangedShape).
// Presumably element-wise cosine (inferred from the name only).
REGISTER_OP("RiscCos")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
154
// Binary op; `x` and `y` must have matching shapes (no broadcasting, see
// RiscBinaryNonBroadcastOpShapeFn).
REGISTER_OP("RiscDiv")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
161
// Matrix product of `a` and `b`, with optional transposes. Shape inference uses
// MatMulShape, which reads the `transpose_a` / `transpose_b` attrs declared here.
REGISTER_OP("RiscDot")
    .Input("a: T")
    .Input("b: T")
    .Output("product: T")
    .Attr("transpose_a: bool = false")
    .Attr("transpose_b: bool = false")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::MatMulShape);
170
// Unary op: `y` has the dtype and shape of `x` (UnchangedShape).
// Presumably element-wise exponential (inferred from the name only).
REGISTER_OP("RiscExp")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
176
// TODO(b/178234771): change shape function.
// Complex-to-complex op (complex64 by default, or complex128). The output shape
// is currently reported as unknown.
REGISTER_OP("RiscFft")
    .Input("input: Tcomplex")
    .Output("output: Tcomplex")
    .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64")
    .SetShapeFn(shape_inference::UnknownShape);
183
// Unary op: `y` has the dtype and shape of `x` (UnchangedShape).
// Presumably element-wise floor (inferred from the name only).
REGISTER_OP("RiscFloor")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
189
// TODO(b/178234771): change shape function.
// Gather-style op over `params`, using int32/int64 `indices` and an int32/int64
// `axis`, with a `batch_dims` attr (default 0). The output has dtype Tparams,
// and its shape is currently reported as unknown.
REGISTER_OP("RiscGather")
    .Input("params: Tparams")
    .Input("indices: Tindices")
    .Input("axis: Taxis")
    .Attr("batch_dims: int = 0")
    .Output("output: Tparams")
    .Attr("Tparams: type")
    .Attr("Tindices: {int32,int64}")
    .Attr("Taxis: {int32,int64}")
    .SetShapeFn(shape_inference::UnknownShape);
201
// Complex input, real (float/double) output with the same shape (UnchangedShape).
// NOTE(review): nothing here ties Tout to T, so e.g. complex128 -> float is
// accepted by the op def. Confirm whether kernels expect the matching precision.
REGISTER_OP("RiscImag")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
208
// Unary predicate: bool output with the same shape as `x` (UnchangedShape).
REGISTER_OP("RiscIsFinite")
    .Input("x: T")
    .Output("y: bool")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
214
// Unary op: `y` has the dtype and shape of `x` (UnchangedShape).
// Presumably element-wise natural log (inferred from the name only).
REGISTER_OP("RiscLog")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
220
// TODO(b/178234771): change shape function.
// Binary bool op. Unlike the arithmetic binary ops, the output shape is
// currently reported as unknown rather than checked for non-broadcast shapes.
REGISTER_OP("RiscLogicalAnd")
    .Input("x: bool")
    .Input("y: bool")
    .Output("z: bool")
    .SetShapeFn(shape_inference::UnknownShape);
227
// Unary bool op; the output keeps the shape of `x` (UnchangedShape).
REGISTER_OP("RiscLogicalNot")
    .Input("x: bool")
    .Output("z: bool")
    .SetShapeFn(shape_inference::UnchangedShape);
232
// TODO(b/178234771): change shape function.
// Binary bool op. The output shape is currently reported as unknown.
REGISTER_OP("RiscLogicalOr")
    .Input("x: bool")
    .Input("y: bool")
    .Output("z: bool")
    .SetShapeFn(shape_inference::UnknownShape);
239
// Binary op; `x` and `y` must have matching shapes (no broadcasting).
// The output is named "max" rather than "z" like the sibling binary ops. It is
// part of the OpDef, so it should not be renamed casually.
REGISTER_OP("RiscMax")
    .Input("x: T")
    .Input("y: T")
    .Output("max: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
246
// Binary op; `x` and `y` must have matching shapes (no broadcasting).
REGISTER_OP("RiscMin")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
253
// Binary op; `x` and `y` must have matching shapes (no broadcasting).
// NOTE(review): unlike RiscAdd, this op is not flagged SetIsCommutative().
REGISTER_OP("RiscMul")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
260
// Unary op: `y` has the dtype and shape of `x` (UnchangedShape).
// Presumably element-wise negation (inferred from the name only).
REGISTER_OP("RiscNeg")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
266
// TODO(b/178234771): change shape function.
// Inputs: `input`, an int32/int64 `paddings` tensor, and a `constant_values`
// tensor of the same dtype as `input`. The output shape is currently reported
// as unknown.
REGISTER_OP("RiscPad")
    .Input("input: T")
    .Input("paddings: Tpaddings")
    .Input("constant_values: T")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnknownShape);
276
// TODO(b/178234771): change shape function.
// AVG or MAX pooling selected by `pooling_type`. `ksize` and `strides` each
// need at least 4 entries. The output shape is currently reported as unknown.
REGISTER_OP("RiscPool")
    .Input("value: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr("pooling_type: {'AVG', 'MAX'}")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnknownShape);
287
// Binary op; `x` and `y` must have matching shapes (no broadcasting).
REGISTER_OP("RiscPow")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
294
// Produces a float tensor. Its shape is derived by RandomShape from the
// int32/int64 `shape` input. `seed` defaults to 0.
REGISTER_OP("RiscRandomUniform")
    .Input("shape: T")
    .Output("output: float")
    .Attr("seed: int = 0")
    .Attr("T: {int32, int64}")
    .SetShapeFn(shape_inference::RandomShape);
301
// Complex input, real (float/double) output with the same shape (UnchangedShape).
// NOTE(review): as with RiscImag, Tout is not constrained to match T's
// precision.
REGISTER_OP("RiscReal")
    .Input("input: T")
    .Output("output: Tout")
    .Attr("T: {complex64, complex128} = DT_COMPLEX64")
    .Attr("Tout: {float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);
308
// TODO(b/178234771): change shape function.
// MEAN or SUM reduction (selected by `reduce_type`) of `tensor` over the
// int32/int64 `axis` input. The output shape is currently reported as unknown.
REGISTER_OP("RiscReduce")
    .Input("tensor: T")
    .Input("axis: Index")
    .Output("output: T")
    .Attr("reduce_type: {'MEAN', 'SUM'}")
    .Attr("Index: {int32,int64} = DT_INT32")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnknownShape);
318
// Binary op; `x` and `y` must have matching shapes (no broadcasting).
REGISTER_OP("RiscRem")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
325
// TODO(b/178234771): change shape function.
// Takes `tensor` and an int32/int64 `shape` input. The output shape is
// currently reported as unknown, rather than read from `shape`.
REGISTER_OP("RiscReshape")
    .Input("tensor: T")
    .Input("shape: Tshape")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnknownShape);
334
// TODO(b/178234771): change shape function.
// Takes `tensor` and an int32/int64 `axis` input. The output shape is
// currently reported as unknown.
REGISTER_OP("RiscReverse")
    .Input("tensor: T")
    .Input("axis: Tidx")
    .Output("output: T")
    .Attr("Tidx: {int32, int64} = DT_INT32")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnknownShape);
343
// TODO(b/178234771): change shape function.
// Inputs: `indices`, `updates`, and a `shape` tensor that shares the Tindices
// dtype with `indices`. The output shape is currently reported as unknown.
REGISTER_OP("RiscScatter")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Input("shape: Tindices")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(shape_inference::UnknownShape);
353
// TODO(b/178234771): change shape function.
// Produces an int32 (default) or int64 tensor from `input`. The output shape is
// currently reported as unknown. NOTE(review): if this returns the input's
// shape, the output could be inferred as a vector of length rank(input).
REGISTER_OP("RiscShape")
    .Input("input: T")
    .Output("output: out_type")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnknownShape);
361
// Unary op: `y` has the dtype and shape of `x` (UnchangedShape).
// Presumably element-wise sign (inferred from the name only).
REGISTER_OP("RiscSign")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
367
// Slice of `input` described by `begin` and `size` (both of dtype Index).
// Shape inference reuses SliceShape, which relies on this input order.
REGISTER_OP("RiscSlice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Output("output: T")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);
376
// Sorts `input` along `axis` in the required `direction` (no default). The
// output keeps the shape of `input` (UnchangedShape).
REGISTER_OP("RiscSort")
    .Input("input: T")
    .Input("axis: Index")
    .Output("output: T")
    .Attr("Index: {int32,int64} = DT_INT32")
    .Attr("T: {bfloat16, half, float, double}")
    .Attr("direction: {'ASCENDING', 'DESCENDING'}")
    .SetShapeFn(shape_inference::UnchangedShape);
385
// TODO(b/178234771): change shape function.
// Takes `input` of any dtype and a `squeeze_dims` list (default empty). The
// output shape is currently reported as unknown.
REGISTER_OP("RiscSqueeze")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: type")
    .Attr("squeeze_dims: list(int) >= 0 = []")
    .SetShapeFn(shape_inference::UnknownShape);
393
// Binary op; `x` and `y` must have matching shapes (no broadcasting).
REGISTER_OP("RiscSub")
    .Input("x: T")
    .Input("y: T")
    .Output("z: T")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn);
400
// TODO(b/178234771): change shape function.
// Takes `x` of any dtype and an int32/int64 `perm` input. The output shape is
// currently reported as unknown.
REGISTER_OP("RiscTranspose")
    .Input("x: T")
    .Input("perm: Tperm")
    .Output("y: T")
    .Attr("T: type")
    .Attr("Tperm: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::UnknownShape);
409
// TODO(b/178234771): change shape function.
// Takes `matrix` and `rhs` of the same dtype, with `lower` (default True) and
// `adjoint` (default False) flags. The output shape is currently reported as
// unknown.
REGISTER_OP("RiscTriangularSolve")
    .Input("matrix: T")
    .Input("rhs: T")
    .Output("output: T")
    .Attr("lower: bool = True")
    .Attr("adjoint: bool = False")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnknownShape);
419
// TODO(b/178234771): retire this.
// Generic unary op whose operation is selected by `op_type`. The output keeps
// the dtype and shape of `x` (UnchangedShape).
// NOTE(review): 'ABL' looks like a typo for 'ABS' (cf. RiscAbs). It is left
// as-is because changing an allowed attr value would break existing graphs;
// confirm against the kernels.
REGISTER_OP("RiscUnary")
    .Input("x: T")
    .Output("y: T")
    .Attr(
        "op_type: {'ABL', 'CEIL', 'COS', 'EXP', 'FLOOR', 'IMAG', 'LOG', 'NEG', "
        "'REAL', 'SIGN'}")
    .Attr("T: {bfloat16, half, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);
429
// TODO(b/178234771): change shape function.
// Functional loop over a (possibly empty) list of tensors with types `T`,
// driven by the `cond` and `body` functions. The op is marked stateful.
// The `output_shapes` attr is declared, but the shape function is UnknownShape,
// so output shapes are not inferred from it here.
REGISTER_OP("RiscWhile")
    .Input("input: T")
    .Output("output: T")
    .Attr("T: list(type) >= 0")
    .Attr("cond: func")
    .Attr("body: func")
    .Attr("output_shapes: list(shape) = []")
    .Attr("parallel_iterations: int = 10")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);
441
442} // namespace tensorflow
443