1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/framework/register_types.h" |
20 | #include "tensorflow/core/framework/shape_inference.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | namespace { |
25 | Status RiscBinaryNonBroadcastOpShapeFn(shape_inference::InferenceContext* c) { |
26 | const auto rank = c->Rank(c->input(0)); |
27 | if (rank != c->Rank(c->input(1))) { |
28 | return errors::InvalidArgument("Mismatch rank for input." ); |
29 | } |
30 | for (int i = 0; i < rank; ++i) { |
31 | if (!c->ValueKnown(c->Dim(c->input(0), i)) || |
32 | !c->ValueKnown(c->Dim(c->input(1), i))) { |
33 | continue; |
34 | } |
35 | if (c->Value(c->Dim(c->input(0), i)) != c->Value(c->Dim(c->input(1), i))) { |
36 | return errors::InvalidArgument("Mismatch shapes for input." ); |
37 | } |
38 | } |
39 | c->set_output(0, c->input(0)); |
40 | auto* handle_data = c->input_handle_shapes_and_types(0); |
41 | if (handle_data != nullptr) { |
42 | c->set_output_handle_shapes_and_types(0, *handle_data); |
43 | } |
44 | return OkStatus(); |
45 | } |
46 | } // namespace |
47 | |
48 | REGISTER_OP("RiscAbs" ) |
49 | .Input("x: T" ) |
50 | .Output("y: T" ) |
51 | .Attr("T: {bfloat16, half, float, double}" ) |
52 | .SetShapeFn(shape_inference::UnchangedShape); |
53 | |
54 | REGISTER_OP("RiscAdd" ) |
55 | .Input("x: T" ) |
56 | .Input("y: T" ) |
57 | .Output("z: T" ) |
58 | .Attr("T: {bfloat16, half, float, double}" ) |
59 | .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn) |
60 | .SetIsAggregate() |
61 | .SetIsCommutative(); |
62 | |
63 | // TODO(b/178234771): retire this. |
64 | REGISTER_OP("RiscBinaryArithmetic" ) |
65 | .Input("x: T" ) |
66 | .Input("y: T" ) |
67 | .Output("z: T" ) |
68 | .Attr("op_type: {'ADD', 'SUB', 'MUL', 'DIV', 'REM', 'MIN', 'POW'}" ) |
69 | .Attr("T: {bfloat16, half, float, double}" ) |
70 | .SetShapeFn(shape_inference::UnchangedShape); |
71 | |
72 | REGISTER_OP("RiscBinaryComparison" ) |
73 | .Input("x: T" ) |
74 | .Input("y: T" ) |
75 | .Output("z: bool" ) |
76 | .Attr("op_type: {'EQ', 'NE', 'GE', 'GT', 'LE', 'LT'}" ) |
77 | .Attr("T: {bfloat16, half, float, double}" ) |
78 | .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn); |
79 | |
80 | // TODO(b/178234771): change shape function. |
81 | REGISTER_OP("RiscBitcast" ) |
82 | .Input("x: SrcT" ) |
83 | .Output("y: DstT" ) |
84 | .Attr("SrcT: type" ) |
85 | .Attr("DstT: type" ) |
86 | .SetShapeFn(shape_inference::UnknownShape); |
87 | |
88 | // TODO(b/178234771): change shape function. |
89 | REGISTER_OP("RiscBroadcast" ) |
90 | .Input("input: T" ) |
91 | .Input("shape: Tidx" ) |
92 | .Output("output: T" ) |
93 | .Attr("T: type" ) |
94 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
95 | .SetShapeFn(shape_inference::UnknownShape); |
96 | |
97 | REGISTER_OP("RiscCast" ) |
98 | .Input("x: SrcT" ) |
99 | .Output("y: DstT" ) |
100 | .Attr("SrcT: type" ) |
101 | .Attr("DstT: type" ) |
102 | .SetShapeFn(shape_inference::UnchangedShape); |
103 | |
104 | REGISTER_OP("RiscCeil" ) |
105 | .Input("x: T" ) |
106 | .Output("y: T" ) |
107 | .Attr("T: {bfloat16, half, float, double}" ) |
108 | .SetShapeFn(shape_inference::UnchangedShape); |
109 | |
110 | // TODO(b/178234771): change shape function. |
111 | REGISTER_OP("RiscCholesky" ) |
112 | .Input("input: T" ) |
113 | .Output("output: T" ) |
114 | .Attr("T: {bfloat16, half, float, double}" ) |
115 | .SetShapeFn(shape_inference::UnknownShape); |
116 | |
117 | REGISTER_OP("RiscConcat" ) |
118 | .Input("values: N * T" ) |
119 | .Input("axis: Tidx" ) |
120 | .Output("output: T" ) |
121 | .Attr("N: int >= 2" ) |
122 | .Attr("T: type" ) |
123 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
124 | .SetShapeFn(shape_inference::ConcatV2Shape); |
125 | |
126 | // TODO(b/178234771): change shape function. |
127 | REGISTER_OP("RiscCondition" ) |
128 | .Input("pred: bool" ) |
129 | .Input("input_true: SrcT" ) |
130 | .Input("input_false: SrcT" ) |
131 | .Output("output: DstT" ) |
132 | .Attr("func_true: func" ) |
133 | .Attr("func_false: func" ) |
134 | .Attr("SrcT: {bfloat16, half, float, double}" ) |
135 | .Attr("DstT: {bfloat16, half, float, double}" ) |
136 | .SetShapeFn(shape_inference::UnknownShape); |
137 | |
138 | // TODO(b/178234771): change shape function. |
139 | REGISTER_OP("RiscConv" ) |
140 | .Input("input: T" ) |
141 | .Input("filter: T" ) |
142 | .Output("output: T" ) |
143 | .Attr("T: {bfloat16, half, float, double}" ) |
144 | .Attr("strides: list(int)" ) |
145 | .Attr(GetConvnetDataFormatAttrString()) |
146 | .SetShapeFn(shape_inference::UnknownShape) |
147 | .Attr("dilations: list(int) = [1, 1, 1, 1]" ); |
148 | |
149 | REGISTER_OP("RiscCos" ) |
150 | .Input("x: T" ) |
151 | .Output("y: T" ) |
152 | .Attr("T: {bfloat16, half, float, double}" ) |
153 | .SetShapeFn(shape_inference::UnchangedShape); |
154 | |
155 | REGISTER_OP("RiscDiv" ) |
156 | .Input("x: T" ) |
157 | .Input("y: T" ) |
158 | .Output("z: T" ) |
159 | .Attr("T: {bfloat16, half, float, double}" ) |
160 | .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn); |
161 | |
162 | REGISTER_OP("RiscDot" ) |
163 | .Input("a: T" ) |
164 | .Input("b: T" ) |
165 | .Output("product: T" ) |
166 | .Attr("transpose_a: bool = false" ) |
167 | .Attr("transpose_b: bool = false" ) |
168 | .Attr("T: {bfloat16, half, float, double}" ) |
169 | .SetShapeFn(shape_inference::MatMulShape); |
170 | |
171 | REGISTER_OP("RiscExp" ) |
172 | .Input("x: T" ) |
173 | .Output("y: T" ) |
174 | .Attr("T: {bfloat16, half, float, double}" ) |
175 | .SetShapeFn(shape_inference::UnchangedShape); |
176 | |
177 | // TODO(b/178234771): change shape function. |
178 | REGISTER_OP("RiscFft" ) |
179 | .Input("input: Tcomplex" ) |
180 | .Output("output: Tcomplex" ) |
181 | .Attr("Tcomplex: {complex64, complex128} = DT_COMPLEX64" ) |
182 | .SetShapeFn(shape_inference::UnknownShape); |
183 | |
184 | REGISTER_OP("RiscFloor" ) |
185 | .Input("x: T" ) |
186 | .Output("y: T" ) |
187 | .Attr("T: {bfloat16, half, float, double}" ) |
188 | .SetShapeFn(shape_inference::UnchangedShape); |
189 | |
190 | // TODO(b/178234771): change shape function. |
191 | REGISTER_OP("RiscGather" ) |
192 | .Input("params: Tparams" ) |
193 | .Input("indices: Tindices" ) |
194 | .Input("axis: Taxis" ) |
195 | .Attr("batch_dims: int = 0" ) |
196 | .Output("output: Tparams" ) |
197 | .Attr("Tparams: type" ) |
198 | .Attr("Tindices: {int32,int64}" ) |
199 | .Attr("Taxis: {int32,int64}" ) |
200 | .SetShapeFn(shape_inference::UnknownShape); |
201 | |
202 | REGISTER_OP("RiscImag" ) |
203 | .Input("input: T" ) |
204 | .Output("output: Tout" ) |
205 | .Attr("T: {complex64, complex128} = DT_COMPLEX64" ) |
206 | .Attr("Tout: {float, double} = DT_FLOAT" ) |
207 | .SetShapeFn(shape_inference::UnchangedShape); |
208 | |
209 | REGISTER_OP("RiscIsFinite" ) |
210 | .Input("x: T" ) |
211 | .Output("y: bool" ) |
212 | .Attr("T: {bfloat16, half, float, double}" ) |
213 | .SetShapeFn(shape_inference::UnchangedShape); |
214 | |
215 | REGISTER_OP("RiscLog" ) |
216 | .Input("x: T" ) |
217 | .Output("y: T" ) |
218 | .Attr("T: {bfloat16, half, float, double}" ) |
219 | .SetShapeFn(shape_inference::UnchangedShape); |
220 | |
221 | // TODO(b/178234771): change shape function. |
222 | REGISTER_OP("RiscLogicalAnd" ) |
223 | .Input("x: bool" ) |
224 | .Input("y: bool" ) |
225 | .Output("z: bool" ) |
226 | .SetShapeFn(shape_inference::UnknownShape); |
227 | |
228 | REGISTER_OP("RiscLogicalNot" ) |
229 | .Input("x: bool" ) |
230 | .Output("z: bool" ) |
231 | .SetShapeFn(shape_inference::UnchangedShape); |
232 | |
233 | // TODO(b/178234771): change shape function. |
234 | REGISTER_OP("RiscLogicalOr" ) |
235 | .Input("x: bool" ) |
236 | .Input("y: bool" ) |
237 | .Output("z: bool" ) |
238 | .SetShapeFn(shape_inference::UnknownShape); |
239 | |
240 | REGISTER_OP("RiscMax" ) |
241 | .Input("x: T" ) |
242 | .Input("y: T" ) |
243 | .Output("max: T" ) |
244 | .Attr("T: {bfloat16, half, float, double}" ) |
245 | .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn); |
246 | |
247 | REGISTER_OP("RiscMin" ) |
248 | .Input("x: T" ) |
249 | .Input("y: T" ) |
250 | .Output("z: T" ) |
251 | .Attr("T: {bfloat16, half, float, double}" ) |
252 | .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn); |
253 | |
254 | REGISTER_OP("RiscMul" ) |
255 | .Input("x: T" ) |
256 | .Input("y: T" ) |
257 | .Output("z: T" ) |
258 | .Attr("T: {bfloat16, half, float, double}" ) |
259 | .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn); |
260 | |
261 | REGISTER_OP("RiscNeg" ) |
262 | .Input("x: T" ) |
263 | .Output("y: T" ) |
264 | .Attr("T: {bfloat16, half, float, double}" ) |
265 | .SetShapeFn(shape_inference::UnchangedShape); |
266 | |
267 | // TODO(b/178234771): change shape function. |
268 | REGISTER_OP("RiscPad" ) |
269 | .Input("input: T" ) |
270 | .Input("paddings: Tpaddings" ) |
271 | .Input("constant_values: T" ) |
272 | .Output("output: T" ) |
273 | .Attr("T: {bfloat16, half, float, double}" ) |
274 | .Attr("Tpaddings: {int32, int64} = DT_INT32" ) |
275 | .SetShapeFn(shape_inference::UnknownShape); |
276 | |
277 | // TODO(b/178234771): change shape function. |
278 | REGISTER_OP("RiscPool" ) |
279 | .Input("value: T" ) |
280 | .Output("output: T" ) |
281 | .Attr("ksize: list(int) >= 4" ) |
282 | .Attr("strides: list(int) >= 4" ) |
283 | .Attr("pooling_type: {'AVG', 'MAX'}" ) |
284 | .Attr(GetConvnetDataFormatAttrString()) |
285 | .Attr("T: {bfloat16, half, float, double}" ) |
286 | .SetShapeFn(shape_inference::UnknownShape); |
287 | |
288 | REGISTER_OP("RiscPow" ) |
289 | .Input("x: T" ) |
290 | .Input("y: T" ) |
291 | .Output("z: T" ) |
292 | .Attr("T: {bfloat16, half, float, double}" ) |
293 | .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn); |
294 | |
295 | REGISTER_OP("RiscRandomUniform" ) |
296 | .Input("shape: T" ) |
297 | .Output("output: float" ) |
298 | .Attr("seed: int = 0" ) |
299 | .Attr("T: {int32, int64}" ) |
300 | .SetShapeFn(shape_inference::RandomShape); |
301 | |
302 | REGISTER_OP("RiscReal" ) |
303 | .Input("input: T" ) |
304 | .Output("output: Tout" ) |
305 | .Attr("T: {complex64, complex128} = DT_COMPLEX64" ) |
306 | .Attr("Tout: {float, double} = DT_FLOAT" ) |
307 | .SetShapeFn(shape_inference::UnchangedShape); |
308 | |
309 | // TODO(b/178234771): change shape function. |
310 | REGISTER_OP("RiscReduce" ) |
311 | .Input("tensor: T" ) |
312 | .Input("axis: Index" ) |
313 | .Output("output: T" ) |
314 | .Attr("reduce_type: {'MEAN', 'SUM'}" ) |
315 | .Attr("Index: {int32,int64} = DT_INT32" ) |
316 | .Attr("T: {bfloat16, half, float, double}" ) |
317 | .SetShapeFn(shape_inference::UnknownShape); |
318 | |
319 | REGISTER_OP("RiscRem" ) |
320 | .Input("x: T" ) |
321 | .Input("y: T" ) |
322 | .Output("z: T" ) |
323 | .Attr("T: {bfloat16, half, float, double}" ) |
324 | .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn); |
325 | |
326 | // TODO(b/178234771): change shape function. |
327 | REGISTER_OP("RiscReshape" ) |
328 | .Input("tensor: T" ) |
329 | .Input("shape: Tshape" ) |
330 | .Output("output: T" ) |
331 | .Attr("T: {bfloat16, half, float, double}" ) |
332 | .Attr("Tshape: {int32, int64} = DT_INT32" ) |
333 | .SetShapeFn(shape_inference::UnknownShape); |
334 | |
335 | // TODO(b/178234771): change shape function. |
336 | REGISTER_OP("RiscReverse" ) |
337 | .Input("tensor: T" ) |
338 | .Input("axis: Tidx" ) |
339 | .Output("output: T" ) |
340 | .Attr("Tidx: {int32, int64} = DT_INT32" ) |
341 | .Attr("T: {bfloat16, half, float, double}" ) |
342 | .SetShapeFn(shape_inference::UnknownShape); |
343 | |
344 | // TODO(b/178234771): change shape function. |
345 | REGISTER_OP("RiscScatter" ) |
346 | .Input("indices: Tindices" ) |
347 | .Input("updates: T" ) |
348 | .Input("shape: Tindices" ) |
349 | .Output("output: T" ) |
350 | .Attr("T: {bfloat16, half, float, double}" ) |
351 | .Attr("Tindices: {int32, int64}" ) |
352 | .SetShapeFn(shape_inference::UnknownShape); |
353 | |
354 | // TODO(b/178234771): change shape function. |
355 | REGISTER_OP("RiscShape" ) |
356 | .Input("input: T" ) |
357 | .Output("output: out_type" ) |
358 | .Attr("T: {bfloat16, half, float, double}" ) |
359 | .Attr("out_type: {int32, int64} = DT_INT32" ) |
360 | .SetShapeFn(shape_inference::UnknownShape); |
361 | |
362 | REGISTER_OP("RiscSign" ) |
363 | .Input("x: T" ) |
364 | .Output("y: T" ) |
365 | .Attr("T: {bfloat16, half, float, double}" ) |
366 | .SetShapeFn(shape_inference::UnchangedShape); |
367 | |
368 | REGISTER_OP("RiscSlice" ) |
369 | .Input("input: T" ) |
370 | .Input("begin: Index" ) |
371 | .Input("size: Index" ) |
372 | .Output("output: T" ) |
373 | .Attr("T: {bfloat16, half, float, double}" ) |
374 | .Attr("Index: {int32,int64}" ) |
375 | .SetShapeFn(shape_inference::SliceShape); |
376 | |
377 | REGISTER_OP("RiscSort" ) |
378 | .Input("input: T" ) |
379 | .Input("axis: Index" ) |
380 | .Output("output: T" ) |
381 | .Attr("Index: {int32,int64} = DT_INT32" ) |
382 | .Attr("T: {bfloat16, half, float, double}" ) |
383 | .Attr("direction: {'ASCENDING', 'DESCENDING'}" ) |
384 | .SetShapeFn(shape_inference::UnchangedShape); |
385 | |
386 | // TODO(b/178234771): change shape function. |
387 | REGISTER_OP("RiscSqueeze" ) |
388 | .Input("input: T" ) |
389 | .Output("output: T" ) |
390 | .Attr("T: type" ) |
391 | .Attr("squeeze_dims: list(int) >= 0 = []" ) |
392 | .SetShapeFn(shape_inference::UnknownShape); |
393 | |
394 | REGISTER_OP("RiscSub" ) |
395 | .Input("x: T" ) |
396 | .Input("y: T" ) |
397 | .Output("z: T" ) |
398 | .Attr("T: {bfloat16, half, float, double}" ) |
399 | .SetShapeFn(RiscBinaryNonBroadcastOpShapeFn); |
400 | |
401 | // TODO(b/178234771): change shape function. |
402 | REGISTER_OP("RiscTranspose" ) |
403 | .Input("x: T" ) |
404 | .Input("perm: Tperm" ) |
405 | .Output("y: T" ) |
406 | .Attr("T: type" ) |
407 | .Attr("Tperm: {int32, int64} = DT_INT32" ) |
408 | .SetShapeFn(shape_inference::UnknownShape); |
409 | |
410 | // TODO(b/178234771): change shape function. |
411 | REGISTER_OP("RiscTriangularSolve" ) |
412 | .Input("matrix: T" ) |
413 | .Input("rhs: T" ) |
414 | .Output("output: T" ) |
415 | .Attr("lower: bool = True" ) |
416 | .Attr("adjoint: bool = False" ) |
417 | .Attr("T: {bfloat16, half, float, double}" ) |
418 | .SetShapeFn(shape_inference::UnknownShape); |
419 | |
420 | // TODO(b/178234771): retire this. |
421 | REGISTER_OP("RiscUnary" ) |
422 | .Input("x: T" ) |
423 | .Output("y: T" ) |
424 | .Attr( |
425 | "op_type: {'ABL', 'CEIL', 'COS', 'EXP', 'FLOOR', 'IMAG', 'LOG', 'NEG', " |
426 | "'REAL', 'SIGN'}" ) |
427 | .Attr("T: {bfloat16, half, float, double}" ) |
428 | .SetShapeFn(shape_inference::UnchangedShape); |
429 | |
430 | // TODO(b/178234771): change shape function. |
431 | REGISTER_OP("RiscWhile" ) |
432 | .Input("input: T" ) |
433 | .Output("output: T" ) |
434 | .Attr("T: list(type) >= 0" ) |
435 | .Attr("cond: func" ) |
436 | .Attr("body: func" ) |
437 | .Attr("output_shapes: list(shape) = []" ) |
438 | .Attr("parallel_iterations: int = 10" ) |
439 | .SetIsStateful() |
440 | .SetShapeFn(shape_inference::UnknownShape); |
441 | |
442 | } // namespace tensorflow |
443 | |