/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {

using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;

namespace {

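// Shared shape function for SparseSparseMinimum and SparseSparseMaximum.
// Validates the (indices, values, shape) triples of both operands and emits
// outputs of statically unknown size: the result indices form a matrix of
// unknown dimensions and the result values form a vector of unknown length.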
Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) {
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
  TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
  TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));  // a_shape
  TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused));  // b_indices
  TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused));  // b_values
  TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused));  // b_shape
  c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
                             InferenceContext::kUnknownDim));
  c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
  return OkStatus();
}

}  // namespace

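// Gradient of SparseAdd with respect to the two operands' values. Only the
// index matrices matter for shape inference: each values gradient is a vector
// with one element per row of the corresponding input's indices.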
REGISTER_OP("SparseAddGrad")
    .Input("backprop_val_grad: T")
    .Input("a_indices: int64")
    .Input("b_indices: int64")
    .Input("sum_indices: int64")
    .Output("a_val_grad: T")
    .Output("b_val_grad: T")
    .Attr("T: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle a_indices;
      ShapeHandle b_indices;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices));
      c->set_output(0, c->Vector(c->Dim(a_indices, 0)));
      c->set_output(1, c->Vector(c->Dim(b_indices, 0)));
      return OkStatus();
    });

REGISTER_OP("SparseAdd")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Input("thresh: Treal")
    .Output("sum_indices: int64")
    .Output("sum_values: T")
    .Output("sum_shape: int64")
    .Attr("T: numbertype")
    .Attr("Treal: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle a_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape));
      c->set_output(
          0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0)));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, a_shape);
      return OkStatus();
    });

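// Matrix product of a SparseTensor and a dense matrix. a_shape is read as a
// constant shape tensor so both operands can be treated as rank 2; the inner
// dimensions are merged exactly as for a dense matmul, e.g. a sparse [m, k]
// times a dense [k, n] yields a dense [m, n] (adjoint_a / adjoint_b swap the
// corresponding dimensions first).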
REGISTER_OP("SparseTensorDenseMatMul")
    .Input("a_indices: Tindices")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b: T")
    .Output("product: T")
    .Attr("T: type")
    .Attr("Tindices: {int32,int64} = DT_INT64")
    .Attr("adjoint_a: bool = false")
    .Attr("adjoint_b: bool = false")
    .SetShapeFn([](InferenceContext* c) {
      DimensionHandle unused_dim;
      ShapeHandle unused;
      ShapeHandle b;
      ShapeHandle a_shape;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // a_indices
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));  // a_values
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape));
      TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b));

      bool adjoint_a;
      bool adjoint_b;
      TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a", &adjoint_a));
      TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b", &adjoint_b));

      DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1);
      DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0);
      DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1);
      DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0);
      TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim));
      c->set_output(0, c->Matrix(output_left, output_right));
      return OkStatus();
    });

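// Serialization of SparseTensors. SerializeSparse packs a single SparseTensor
// into a 3-element vector (indices, values, shape); SerializeManySparse packs
// a minibatch into an [?, 3] matrix, one row per SparseTensor sliced along the
// first dimension.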
REGISTER_OP("SerializeSparse")
    .Input("sparse_indices: int64")
    .Input("sparse_values: T")
    .Input("sparse_shape: int64")
    .Attr("T: type")
    .Output("serialized_sparse: out_type")
    .Attr("out_type: {string, variant} = DT_STRING")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Vector(3));
      return OkStatus();
    });

REGISTER_OP("SerializeManySparse")
    .Input("sparse_indices: int64")
    .Input("sparse_values: T")
    .Input("sparse_shape: int64")
    .Attr("T: type")
    .Output("serialized_sparse: out_type")
    .Attr("out_type: {string, variant} = DT_STRING")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3));
      return OkStatus();
    });

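// Deserialization counterparts. DeserializeSparse accepts any tensor of
// serialized SparseTensors whose innermost dimension is 3, while
// DeserializeManySparse requires the rank-2 [?, 3] layout produced by
// SerializeManySparse. The resulting component shapes are only known at run
// time, so all outputs are left unknown.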
REGISTER_OP("DeserializeSparse")
    .Input("serialized_sparse: Tserialized")
    .Output("sparse_indices: int64")
    .Output("sparse_values: dtype")
    .Output("sparse_shape: int64")
    .Attr("dtype: type")
    .Attr("Tserialized: {string, variant} = DT_STRING")
    .SetShapeFn([](InferenceContext* c) {
      // serialized_sparse is a [?, ..., ?, 3] tensor.
      ShapeHandle unused_shape;
      TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused_shape));
      DimensionHandle unused;
      TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused));
      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
                                 InferenceContext::kUnknownDim));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
      return OkStatus();
    });

REGISTER_OP("DeserializeManySparse")
    .Input("serialized_sparse: string")
    .Output("sparse_indices: int64")
    .Output("sparse_values: dtype")
    .Output("sparse_shape: int64")
    .Attr("dtype: type")
    .SetShapeFn([](InferenceContext* c) {
      // serialized_sparse is a [?, 3] matrix.
      ShapeHandle serialized_sparse;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse));
      DimensionHandle unused;
      TF_RETURN_IF_ERROR(
          c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused));

      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
                                 InferenceContext::kUnknownDim));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
      return OkStatus();
    });

REGISTER_OP("SparseToDense")
    .Input("sparse_indices: Tindices")
    .Input("output_shape: Tindices")
    .Input("sparse_values: T")
    .Input("default_value: T")
    .Attr("validate_indices: bool = true")
    .Attr("T: type")
    .Output("dense: T")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out));
      c->set_output(0, out);
      return OkStatus();
    });

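// Concatenation of N SparseTensors along concat_dim. The shape function adds
// up the row counts of the per-input index matrices and merges the index
// widths and dense shapes across inputs, so rank mismatches are rejected
// statically.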
REGISTER_OP("SparseConcat")
    .Input("indices: N * int64")
    .Input("values: N * T")
    .Input("shapes: N * int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("concat_dim: int")
    .Attr("N: int >= 2")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      // This accumulates the total number of output rows.
      DimensionHandle output_row_count = c->MakeDim(0ll);

      // These are only merged across inputs.
      DimensionHandle output_ind_cols = c->UnknownDim();
      ShapeHandle output_shape = c->UnknownShape();

      const int n = c->num_inputs() / 3;
      for (int i = 0; i < n; i++) {
        ShapeHandle ind;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind));
        ShapeHandle val;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val));
        ShapeHandle shape;
        TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape));

        // Add this input's row count to output_row_count.
        DimensionHandle num_dim;
        TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim));
        TF_RETURN_IF_ERROR(
            c->Add(output_row_count, num_dim, &output_row_count));

        // Merge into output_ind_cols and output_shape.
        TF_RETURN_IF_ERROR(
            c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols));
        TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape));
      }

      c->set_output(0, c->Matrix(output_row_count, output_ind_cols));
      c->set_output(1, c->Vector(output_row_count));
      c->set_output(2, output_shape);
      return OkStatus();
    });

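// Feature-cross ops. All three variants produce a rank-2 SparseTensor: the
// number of crossed values is data dependent, so output_indices is [?, 2],
// output_values is [?], and output_shape is a length-2 dense shape vector.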
REGISTER_OP("SparseCross")
    .Input("indices: N * int64")
    .Input("values: sparse_types")
    .Input("shapes: N * int64")
    .Input("dense_inputs: dense_types")
    .Output("output_indices: int64")
    .Output("output_values: out_type")
    .Output("output_shape: int64")
    .Attr("N: int >= 0")
    .Attr("hashed_output: bool")
    .Attr("num_buckets: int >= 0")
    .Attr("hash_key: int")
    .Attr("sparse_types: list({int64, string}) >= 0")
    .Attr("dense_types: list({int64, string}) >= 0")
    .Attr("out_type: {int64, string}")
    .Attr("internal_type: {int64, string}")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Matrix(c->UnknownDim(), 2));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(2));
      return OkStatus();
    });

REGISTER_OP("SparseCrossV2")
    .Input("indices: N * int64")
    .Input("values: sparse_types")
    .Input("shapes: N * int64")
    .Input("dense_inputs: dense_types")
    .Input("sep: string")
    .Output("output_indices: int64")
    .Output("output_values: string")
    .Output("output_shape: int64")
    .Attr("N: int >= 0")
    .Attr("sparse_types: list({int64, string}) >= 0")
    .Attr("dense_types: list({int64, string}) >= 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Matrix(c->UnknownDim(), 2));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(2));
      return OkStatus();
    });

REGISTER_OP("SparseCrossHashed")
    .Input("indices: N * int64")
    .Input("values: sparse_types")
    .Input("shapes: N * int64")
    .Input("dense_inputs: dense_types")
    .Input("num_buckets: int64")
    .Input("strong_hash: bool")
    .Input("salt: int64")
    .Output("output_indices: int64")
    .Output("output_values: int64")
    .Output("output_shape: int64")
    .Attr("N: int >= 0")
    .Attr("sparse_types: list({int64, string}) >= 0")
    .Attr("dense_types: list({int64, string}) >= 0")
    .SetShapeFn([](shape_inference::InferenceContext* c) {
      c->set_output(0, c->Matrix(c->UnknownDim(), 2));
      c->set_output(1, c->Vector(c->UnknownDim()));
      c->set_output(2, c->Vector(2));
      return OkStatus();
    });

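// Splits a SparseTensor into num_split pieces along split_dim. Every piece
// gets the same statically inferred shapes: indices are [?, rank] (with rank
// taken from the number of elements in the input shape vector), values are
// [?], and each output shape has the same length as the input shape.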
REGISTER_OP("SparseSplit")
    .Input("split_dim: int64")
    .Input("indices: int64")
    .Input("values: T")
    .Input("shape: int64")
    .Output("output_indices: num_split * int64")
    .Output("output_values: num_split * T")
    .Output("output_shape: num_split * int64")
    .Attr("num_split: int >= 1")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input_shape = c->input(3);
      ShapeHandle output_indices = c->Matrix(InferenceContext::kUnknownDim,
                                             c->NumElements(input_shape));
      ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
      ShapeHandle output_shape = input_shape;

      // Copy the outputs into the output ranges.
      int num_splits = c->num_outputs() / 3;
      int out_idx = 0;
      for (int i = 0; i < num_splits; ++i)
        c->set_output(out_idx++, output_indices);
      for (int i = 0; i < num_splits; ++i)
        c->set_output(out_idx++, output_values);
      for (int i = 0; i < num_splits; ++i)
        c->set_output(out_idx++, output_shape);
      return OkStatus();
    });

REGISTER_OP("SparseSliceGrad")
    .Input("backprop_val_grad: T")
    .Input("input_indices: int64")
    .Input("input_start: int64")
    .Input("output_indices: int64")
    .Output("val_grad: T")
    .Attr("T: numbertype")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices));
      c->set_output(0, c->Vector(c->Dim(indices, 0)));
      return OkStatus();
    });

REGISTER_OP("SparseSlice")
    .Input("indices: int64")
    .Input("values: T")
    .Input("shape: int64")
    .Input("start: int64")
    .Input("size: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input_shape = c->input(2);
      ShapeHandle output_indices = c->Matrix(InferenceContext::kUnknownDim,
                                             c->NumElements(input_shape));
      ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
      ShapeHandle output_shape = input_shape;

      c->set_output(0, output_indices);
      c->set_output(1, output_values);
      c->set_output(2, output_shape);
      return OkStatus();
    });

REGISTER_OP("SparseReorder")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices;
      ShapeHandle values;
      ShapeHandle unused;

      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));

      c->set_output(0, indices);
      c->set_output(1, values);
      return OkStatus();
    });

REGISTER_OP("SparseReshape")
    .Input("input_indices: int64")
    .Input("input_shape: int64")
    .Input("new_shape: int64")
    .Output("output_indices: int64")
    .Output("output_shape: int64")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle indices;
      ShapeHandle unused;
      ShapeHandle new_shape;

      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape));

      c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0)));
      c->set_output(1, new_shape);
      return OkStatus();
    });

REGISTER_OP("SparseTensorDenseAdd")
    .Input("a_indices: Tindices")
    .Input("a_values: T")
    .Input("a_shape: Tindices")
    .Input("b: T")
    .Output("output: T")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->input(3));
      return OkStatus();
    });

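// Sparse reductions. The dense-output variants (SparseReduceMax, SparseReduceSum)
// reuse the common SparseReduceShapeFn; the sparse-output variants produce a
// data-dependent number of nonzeros, so their shapes are left unknown.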
REGISTER_OP("SparseReduceMax")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = false")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);

REGISTER_OP("SparseReduceMaxSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = false")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnknownShape);

REGISTER_OP("SparseReduceSum")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = false")
    .Output("output: T")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::SparseReduceShapeFn);

REGISTER_OP("SparseReduceSumSparse")
    .Input("input_indices: int64")
    .Input("input_values: T")
    .Input("input_shape: int64")
    .Input("reduction_axes: int32")
    .Attr("keep_dims: bool = false")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("output_shape: int64")
    .Attr("T: numbertype")
    .SetShapeFn(shape_inference::UnknownShape);

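// The component-wise sparse/dense ops share one signature and shape function:
// the output is a vector with one value per nonzero of the sparse operand,
// i.e. a vector whose length equals the number of rows in sp_indices.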
#define SPARSE_DENSE_CWISE_SIGNATURE()                           \
  Input("sp_indices: int64")                                     \
      .Input("sp_values: T")                                     \
      .Input("sp_shape: int64")                                  \
      .Input("dense: T")                                         \
      .Output("output: T")                                       \
      .Attr("T: numbertype")                                     \
      .SetShapeFn([](InferenceContext* c) {                      \
        ShapeHandle input;                                       \
        TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); \
        c->set_output(0, c->Vector(c->Dim(input, 0)));           \
        return OkStatus();                                       \
      })

REGISTER_OP("SparseDenseCwiseMul").SPARSE_DENSE_CWISE_SIGNATURE();

REGISTER_OP("SparseDenseCwiseDiv").SPARSE_DENSE_CWISE_SIGNATURE();

REGISTER_OP("SparseDenseCwiseAdd").SPARSE_DENSE_CWISE_SIGNATURE();

#undef SPARSE_DENSE_CWISE_SIGNATURE

REGISTER_OP("SparseSoftmax")
    .Input("sp_indices: int64")
    .Input("sp_values: T")
    .Input("sp_shape: int64")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      ShapeHandle values;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));  // sp_indices
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values));  // sp_values
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, values);
      return OkStatus();
    });

REGISTER_OP("SparseSparseMaximum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);

REGISTER_OP("SparseSparseMinimum")
    .Input("a_indices: int64")
    .Input("a_values: T")
    .Input("a_shape: int64")
    .Input("b_indices: int64")
    .Input("b_values: T")
    .Input("b_shape: int64")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Attr("T: numbertype")
    .SetShapeFn(SparseSparseMinOrMaxShapeFn);

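// Stateful ops that stage SparseTensors in a container-backed map keyed by
// int64 handles. AddSparseToTensorsMap returns a scalar handle for a single
// SparseTensor, AddManySparseToTensorsMap returns a vector of handles, and
// TakeManySparseFromTensorsMap reassembles a batch from such handles.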
REGISTER_OP("AddSparseToTensorsMap")
    .Input("sparse_indices: int64")
    .Input("sparse_values: T")
    .Input("sparse_shape: int64")
    .Output("sparse_handle: int64")
    .Attr("T: type")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Scalar());
      return OkStatus();
    });

REGISTER_OP("AddManySparseToTensorsMap")
    .Input("sparse_indices: int64")
    .Input("sparse_values: T")
    .Input("sparse_shape: int64")
    .Output("sparse_handles: int64")
    .Attr("T: type")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
      TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
      c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
      return OkStatus();
    });

REGISTER_OP("TakeManySparseFromTensorsMap")
    .Input("sparse_handles: int64")
    .Output("sparse_indices: int64")
    .Output("sparse_values: dtype")
    .Output("sparse_shape: int64")
    .Attr("dtype: type")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      // sparse_handles is a [?] vector of handles.
      ShapeHandle sparse_handles;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles));

      c->set_output(0, c->Matrix(InferenceContext::kUnknownDim,
                                 InferenceContext::kUnknownDim));
      c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
      c->set_output(2, c->Vector(InferenceContext::kUnknownDim));
      return OkStatus();
    });

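// Inserts default_value into every empty row of the input SparseTensor. The
// shape function checks that indices, values, and dense_shape are mutually
// consistent, rejects an empty dense_shape, and sizes empty_row_indicator from
// dense_shape when it is available as a constant.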
REGISTER_OP("SparseFillEmptyRows")
    .Input("indices: int64")
    .Input("values: T")
    .Input("dense_shape: int64")
    .Input("default_value: T")
    .Output("output_indices: int64")
    .Output("output_values: T")
    .Output("empty_row_indicator: bool")
    .Output("reverse_index_map: int64")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle input_indices = c->input(0);
      TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices));
      ShapeHandle input_values = c->input(1);
      TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values));
      ShapeHandle input_shape = c->input(2);
      TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape));
      ShapeHandle default_value = c->input(3);
      TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value));
      DimensionHandle N = c->Dim(input_indices, 0);
      TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N));
      DimensionHandle unused_dim;
      TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1),
                                  c->Dim(input_shape, 0), &unused_dim));
      if (c->Value(c->NumElements(input_shape)) == 0)
        return errors::InvalidArgument("dense_shape must not be empty");
      ShapeHandle output_indices = c->Matrix(InferenceContext::kUnknownDim,
                                             c->NumElements(input_shape));
      ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim);
      ShapeHandle constant_input_shape;
      TF_RETURN_IF_ERROR(
          c->MakeShapeFromShapeTensor(2, &constant_input_shape));
      ShapeHandle empty_row_indicator =
          c->Vector(c->Dim(constant_input_shape, 0));
      ShapeHandle reverse_index_map = c->Vector(N);
      c->set_output(0, output_indices);
      c->set_output(1, output_values);
      c->set_output(2, empty_row_indicator);
      c->set_output(3, reverse_index_map);
      return OkStatus();
    });

REGISTER_OP("SparseFillEmptyRowsGrad")
    .Input("reverse_index_map: int64")
    .Input("grad_values: T")
    .Output("d_values: T")
    .Output("d_default_value: T")
    .Attr("T: type")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle reverse_index_map = c->input(0);
      TF_RETURN_IF_ERROR(
          c->WithRank(reverse_index_map, 1, &reverse_index_map));
      ShapeHandle grad_values = c->input(1);
      TF_RETURN_IF_ERROR(c->WithRank(grad_values, 1, &grad_values));
      c->set_output(0, reverse_index_map);
      c->set_output(1, c->Scalar());
      return OkStatus();
    });

}  // namespace tensorflow