1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | #include "tensorflow/core/framework/types.pb.h" |
20 | #include "tensorflow/core/platform/errors.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | using shape_inference::DimensionHandle; |
25 | using shape_inference::InferenceContext; |
26 | using shape_inference::ShapeHandle; |
27 | |
28 | namespace { |
29 | |
30 | Status SparseSparseMinOrMaxShapeFn(InferenceContext* c) { |
31 | ShapeHandle unused; |
32 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices |
33 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values |
34 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); // a_shape |
35 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &unused)); // b_indices |
36 | TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 1, &unused)); // b_values |
37 | TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 1, &unused)); // b_shape |
38 | c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, |
39 | InferenceContext::kUnknownDim)); |
40 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
41 | return OkStatus(); |
42 | } |
43 | |
44 | } // namespace |
45 | |
46 | REGISTER_OP("SparseAddGrad" ) |
47 | .Input("backprop_val_grad: T" ) |
48 | .Input("a_indices: int64" ) |
49 | .Input("b_indices: int64" ) |
50 | .Input("sum_indices: int64" ) |
51 | .Output("a_val_grad: T" ) |
52 | .Output("b_val_grad: T" ) |
53 | .Attr("T: numbertype" ) |
54 | .SetShapeFn([](InferenceContext* c) { |
55 | ShapeHandle a_indices; |
56 | ShapeHandle b_indices; |
57 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &a_indices)); |
58 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &b_indices)); |
59 | c->set_output(0, c->Vector(c->Dim(a_indices, 0))); |
60 | c->set_output(1, c->Vector(c->Dim(b_indices, 0))); |
61 | return OkStatus(); |
62 | }); |
63 | |
64 | REGISTER_OP("SparseAdd" ) |
65 | .Input("a_indices: int64" ) |
66 | .Input("a_values: T" ) |
67 | .Input("a_shape: int64" ) |
68 | .Input("b_indices: int64" ) |
69 | .Input("b_values: T" ) |
70 | .Input("b_shape: int64" ) |
71 | .Input("thresh: Treal" ) |
72 | .Output("sum_indices: int64" ) |
73 | .Output("sum_values: T" ) |
74 | .Output("sum_shape: int64" ) |
75 | .Attr("T: numbertype" ) |
76 | .Attr("Treal: realnumbertype" ) |
77 | .SetShapeFn([](InferenceContext* c) { |
78 | ShapeHandle a_shape; |
79 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &a_shape)); |
80 | c->set_output( |
81 | 0, c->Matrix(InferenceContext::kUnknownDim, c->Dim(a_shape, 0))); |
82 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
83 | c->set_output(2, a_shape); |
84 | return OkStatus(); |
85 | }); |
86 | |
87 | REGISTER_OP("SparseTensorDenseMatMul" ) |
88 | .Input("a_indices: Tindices" ) |
89 | .Input("a_values: T" ) |
90 | .Input("a_shape: int64" ) |
91 | .Input("b: T" ) |
92 | .Output("product: T" ) |
93 | .Attr("T: type" ) |
94 | .Attr("Tindices: {int32,int64} = DT_INT64" ) |
95 | .Attr("adjoint_a: bool = false" ) |
96 | .Attr("adjoint_b: bool = false" ) |
97 | .SetShapeFn([](InferenceContext* c) { |
98 | DimensionHandle unused_dim; |
99 | ShapeHandle unused; |
100 | ShapeHandle b; |
101 | ShapeHandle a_shape; |
102 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // a_indices |
103 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // a_values |
104 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &a_shape)); |
105 | TF_RETURN_IF_ERROR(c->WithRank(a_shape, 2, &a_shape)); |
106 | TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 2, &b)); |
107 | |
108 | bool adjoint_a; |
109 | bool adjoint_b; |
110 | TF_RETURN_IF_ERROR(c->GetAttr("adjoint_a" , &adjoint_a)); |
111 | TF_RETURN_IF_ERROR(c->GetAttr("adjoint_b" , &adjoint_b)); |
112 | |
113 | DimensionHandle output_right = c->Dim(b, adjoint_b ? 0 : 1); |
114 | DimensionHandle output_left = c->Dim(a_shape, adjoint_a ? 1 : 0); |
115 | DimensionHandle inner_left = c->Dim(a_shape, adjoint_a ? 0 : 1); |
116 | DimensionHandle inner_right = c->Dim(b, adjoint_b ? 1 : 0); |
117 | TF_RETURN_IF_ERROR(c->Merge(inner_left, inner_right, &unused_dim)); |
118 | c->set_output(0, c->Matrix(output_left, output_right)); |
119 | return OkStatus(); |
120 | }); |
121 | |
122 | REGISTER_OP("SerializeSparse" ) |
123 | .Input("sparse_indices: int64" ) |
124 | .Input("sparse_values: T" ) |
125 | .Input("sparse_shape: int64" ) |
126 | .Attr("T: type" ) |
127 | .Output("serialized_sparse: out_type" ) |
128 | .Attr("out_type: {string, variant} = DT_STRING" ) |
129 | .SetShapeFn([](InferenceContext* c) { |
130 | ShapeHandle unused; |
131 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); |
132 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
133 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
134 | c->set_output(0, c->Vector(3)); |
135 | return OkStatus(); |
136 | }); |
137 | |
138 | REGISTER_OP("SerializeManySparse" ) |
139 | .Input("sparse_indices: int64" ) |
140 | .Input("sparse_values: T" ) |
141 | .Input("sparse_shape: int64" ) |
142 | .Attr("T: type" ) |
143 | .Output("serialized_sparse: out_type" ) |
144 | .Attr("out_type: {string, variant} = DT_STRING" ) |
145 | .SetShapeFn([](InferenceContext* c) { |
146 | ShapeHandle unused; |
147 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); |
148 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
149 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
150 | c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 3)); |
151 | return OkStatus(); |
152 | }); |
153 | |
154 | REGISTER_OP("DeserializeSparse" ) |
155 | .Input("serialized_sparse: Tserialized" ) |
156 | .Output("sparse_indices: int64" ) |
157 | .Output("sparse_values: dtype" ) |
158 | .Output("sparse_shape: int64" ) |
159 | .Attr("dtype: type" ) |
160 | .Attr("Tserialized: {string, variant} = DT_STRING" ) |
161 | .SetShapeFn([](InferenceContext* c) { |
162 | // serialized sparse is [?, ..., ?, 3] vector. |
163 | ShapeHandle unused_shape; |
164 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &unused_shape)); |
165 | DimensionHandle unused; |
166 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(c->input(0), -1), 3, &unused)); |
167 | c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, |
168 | InferenceContext::kUnknownDim)); |
169 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
170 | c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); |
171 | return OkStatus(); |
172 | }); |
173 | |
174 | REGISTER_OP("DeserializeManySparse" ) |
175 | .Input("serialized_sparse: string" ) |
176 | .Output("sparse_indices: int64" ) |
177 | .Output("sparse_values: dtype" ) |
178 | .Output("sparse_shape: int64" ) |
179 | .Attr("dtype: type" ) |
180 | .SetShapeFn([](InferenceContext* c) { |
181 | // serialized sparse is [?,3] matrix. |
182 | ShapeHandle serialized_sparse; |
183 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &serialized_sparse)); |
184 | DimensionHandle unused; |
185 | TF_RETURN_IF_ERROR( |
186 | c->WithValue(c->Dim(serialized_sparse, 1), 3, &unused)); |
187 | |
188 | c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, |
189 | InferenceContext::kUnknownDim)); |
190 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
191 | c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); |
192 | return OkStatus(); |
193 | }); |
194 | |
195 | REGISTER_OP("SparseToDense" ) |
196 | .Input("sparse_indices: Tindices" ) |
197 | .Input("output_shape: Tindices" ) |
198 | .Input("sparse_values: T" ) |
199 | .Input("default_value: T" ) |
200 | .Attr("validate_indices: bool = true" ) |
201 | .Attr("T: type" ) |
202 | .Output("dense: T" ) |
203 | .Attr("Tindices: {int32, int64}" ) |
204 | .SetShapeFn([](InferenceContext* c) { |
205 | ShapeHandle out; |
206 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &out)); |
207 | c->set_output(0, out); |
208 | return OkStatus(); |
209 | }); |
210 | |
211 | REGISTER_OP("SparseConcat" ) |
212 | .Input("indices: N * int64" ) |
213 | .Input("values: N * T" ) |
214 | .Input("shapes: N * int64" ) |
215 | .Output("output_indices: int64" ) |
216 | .Output("output_values: T" ) |
217 | .Output("output_shape: int64" ) |
218 | .Attr("concat_dim: int" ) |
219 | .Attr("N: int >= 2" ) |
220 | .Attr("T: type" ) |
221 | .SetShapeFn([](InferenceContext* c) { |
222 | // These accumulates the sum. |
223 | DimensionHandle output_row_count = c->MakeDim(0ll); |
224 | |
225 | // These are only merged. |
226 | DimensionHandle output_ind_cols = c->UnknownDim(); |
227 | ShapeHandle output_shape = c->UnknownShape(); |
228 | |
229 | const int n = c->num_inputs() / 3; |
230 | for (int i = 0; i < n; i++) { |
231 | ShapeHandle ind; |
232 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &ind)); |
233 | ShapeHandle val; |
234 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i + n), 1, &val)); |
235 | ShapeHandle shape; |
236 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i + 2 * n), 1, &shape)); |
237 | |
238 | // Add to output_ind_rows. |
239 | DimensionHandle num_dim; |
240 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(ind, 0), c->Dim(val, 0), &num_dim)); |
241 | TF_RETURN_IF_ERROR( |
242 | c->Add(output_row_count, num_dim, &output_row_count)); |
243 | |
244 | // Merge into output_ind_cols and output_shape. |
245 | TF_RETURN_IF_ERROR( |
246 | c->Merge(output_ind_cols, c->Dim(ind, 1), &output_ind_cols)); |
247 | TF_RETURN_IF_ERROR(c->Merge(output_shape, shape, &output_shape)); |
248 | } |
249 | |
250 | c->set_output(0, c->Matrix(output_row_count, output_ind_cols)); |
251 | c->set_output(1, c->Vector(output_row_count)); |
252 | c->set_output(2, output_shape); |
253 | return OkStatus(); |
254 | }); |
255 | |
256 | REGISTER_OP("SparseCross" ) |
257 | .Input("indices: N * int64" ) |
258 | .Input("values: sparse_types" ) |
259 | .Input("shapes: N * int64" ) |
260 | .Input("dense_inputs: dense_types" ) |
261 | .Output("output_indices: int64" ) |
262 | .Output("output_values: out_type" ) |
263 | .Output("output_shape: int64" ) |
264 | .Attr("N: int >= 0" ) |
265 | .Attr("hashed_output: bool" ) |
266 | .Attr("num_buckets: int >= 0" ) |
267 | .Attr("hash_key: int" ) |
268 | .Attr("sparse_types: list({int64, string}) >= 0" ) |
269 | .Attr("dense_types: list({int64, string}) >= 0" ) |
270 | .Attr("out_type: {int64, string}" ) |
271 | .Attr("internal_type: {int64, string}" ) |
272 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
273 | c->set_output(0, c->Matrix(c->UnknownDim(), 2)); |
274 | c->set_output(1, c->Vector(c->UnknownDim())); |
275 | c->set_output(2, c->Vector(2)); |
276 | return OkStatus(); |
277 | }); |
278 | |
279 | REGISTER_OP("SparseCrossV2" ) |
280 | .Input("indices: N * int64" ) |
281 | .Input("values: sparse_types" ) |
282 | .Input("shapes: N * int64" ) |
283 | .Input("dense_inputs: dense_types" ) |
284 | .Input("sep: string" ) |
285 | .Output("output_indices: int64" ) |
286 | .Output("output_values: string" ) |
287 | .Output("output_shape: int64" ) |
288 | .Attr("N: int >= 0" ) |
289 | .Attr("sparse_types: list({int64, string}) >= 0" ) |
290 | .Attr("dense_types: list({int64, string}) >= 0" ) |
291 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
292 | c->set_output(0, c->Matrix(c->UnknownDim(), 2)); |
293 | c->set_output(1, c->Vector(c->UnknownDim())); |
294 | c->set_output(2, c->Vector(2)); |
295 | return OkStatus(); |
296 | }); |
297 | |
298 | REGISTER_OP("SparseCrossHashed" ) |
299 | .Input("indices: N * int64" ) |
300 | .Input("values: sparse_types" ) |
301 | .Input("shapes: N * int64" ) |
302 | .Input("dense_inputs: dense_types" ) |
303 | .Input("num_buckets: int64" ) |
304 | .Input("strong_hash: bool" ) |
305 | .Input("salt: int64" ) |
306 | .Output("output_indices: int64" ) |
307 | .Output("output_values: int64" ) |
308 | .Output("output_shape: int64" ) |
309 | .Attr("N: int >= 0" ) |
310 | .Attr("sparse_types: list({int64, string}) >= 0" ) |
311 | .Attr("dense_types: list({int64, string}) >= 0" ) |
312 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
313 | c->set_output(0, c->Matrix(c->UnknownDim(), 2)); |
314 | c->set_output(1, c->Vector(c->UnknownDim())); |
315 | c->set_output(2, c->Vector(2)); |
316 | return OkStatus(); |
317 | }); |
318 | |
319 | REGISTER_OP("SparseSplit" ) |
320 | .Input("split_dim: int64" ) |
321 | .Input("indices: int64" ) |
322 | .Input("values: T" ) |
323 | .Input("shape: int64" ) |
324 | .Output("output_indices: num_split * int64" ) |
325 | .Output("output_values: num_split * T" ) |
326 | .Output("output_shape: num_split * int64" ) |
327 | .Attr("num_split: int >= 1" ) |
328 | .Attr("T: type" ) |
329 | .SetShapeFn([](InferenceContext* c) { |
330 | ShapeHandle input_shape = c->input(3); |
331 | ShapeHandle output_indices = |
332 | c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape)); |
333 | ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim); |
334 | ShapeHandle output_shape = input_shape; |
335 | |
336 | // Copy the outputs into the output ranges. |
337 | int num_splits = c->num_outputs() / 3; |
338 | int out_idx = 0; |
339 | for (int i = 0; i < num_splits; ++i) |
340 | c->set_output(out_idx++, output_indices); |
341 | for (int i = 0; i < num_splits; ++i) |
342 | c->set_output(out_idx++, output_values); |
343 | for (int i = 0; i < num_splits; ++i) |
344 | c->set_output(out_idx++, output_shape); |
345 | return OkStatus(); |
346 | }); |
347 | |
348 | REGISTER_OP("SparseSliceGrad" ) |
349 | .Input("backprop_val_grad: T" ) |
350 | .Input("input_indices: int64" ) |
351 | .Input("input_start: int64" ) |
352 | .Input("output_indices: int64" ) |
353 | .Output("val_grad: T" ) |
354 | .Attr("T: numbertype" ) |
355 | .SetShapeFn([](InferenceContext* c) { |
356 | ShapeHandle indices; |
357 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &indices)); |
358 | c->set_output(0, c->Vector(c->Dim(indices, 0))); |
359 | return OkStatus(); |
360 | }); |
361 | |
362 | REGISTER_OP("SparseSlice" ) |
363 | .Input("indices: int64" ) |
364 | .Input("values: T" ) |
365 | .Input("shape: int64" ) |
366 | .Input("start: int64" ) |
367 | .Input("size: int64" ) |
368 | .Output("output_indices: int64" ) |
369 | .Output("output_values: T" ) |
370 | .Output("output_shape: int64" ) |
371 | .Attr("T: type" ) |
372 | .SetShapeFn([](InferenceContext* c) { |
373 | ShapeHandle input_shape = c->input(2); |
374 | ShapeHandle output_indices = |
375 | c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape)); |
376 | ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim); |
377 | ShapeHandle output_shape = input_shape; |
378 | |
379 | c->set_output(0, output_indices); |
380 | c->set_output(1, output_values); |
381 | c->set_output(2, output_shape); |
382 | return OkStatus(); |
383 | }); |
384 | |
385 | REGISTER_OP("SparseReorder" ) |
386 | .Input("input_indices: int64" ) |
387 | .Input("input_values: T" ) |
388 | .Input("input_shape: int64" ) |
389 | .Output("output_indices: int64" ) |
390 | .Output("output_values: T" ) |
391 | .Attr("T: type" ) |
392 | .SetShapeFn([](InferenceContext* c) { |
393 | ShapeHandle indices; |
394 | ShapeHandle values; |
395 | ShapeHandle unused; |
396 | |
397 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices)); |
398 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values)); |
399 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
400 | |
401 | c->set_output(0, indices); |
402 | c->set_output(1, values); |
403 | return OkStatus(); |
404 | }); |
405 | |
406 | REGISTER_OP("SparseReshape" ) |
407 | .Input("input_indices: int64" ) |
408 | .Input("input_shape: int64" ) |
409 | .Input("new_shape: int64" ) |
410 | .Output("output_indices: int64" ) |
411 | .Output("output_shape: int64" ) |
412 | .SetShapeFn([](InferenceContext* c) { |
413 | ShapeHandle indices; |
414 | ShapeHandle unused; |
415 | ShapeHandle new_shape; |
416 | |
417 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &indices)); |
418 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
419 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &new_shape)); |
420 | |
421 | c->set_output(0, c->Matrix(c->Dim(indices, 0), c->Dim(new_shape, 0))); |
422 | c->set_output(1, new_shape); |
423 | return OkStatus(); |
424 | }); |
425 | |
426 | REGISTER_OP("SparseTensorDenseAdd" ) |
427 | .Input("a_indices: Tindices" ) |
428 | .Input("a_values: T" ) |
429 | .Input("a_shape: Tindices" ) |
430 | .Input("b: T" ) |
431 | .Output("output: T" ) |
432 | .Attr("T: numbertype" ) |
433 | .Attr("Tindices: {int32, int64}" ) |
434 | .SetShapeFn([](InferenceContext* c) { |
435 | c->set_output(0, c->input(3)); |
436 | return OkStatus(); |
437 | }); |
438 | |
439 | REGISTER_OP("SparseReduceMax" ) |
440 | .Input("input_indices: int64" ) |
441 | .Input("input_values: T" ) |
442 | .Input("input_shape: int64" ) |
443 | .Input("reduction_axes: int32" ) |
444 | .Attr("keep_dims: bool = False" ) |
445 | .Output("output: T" ) |
446 | .Attr("T: realnumbertype" ) |
447 | .SetShapeFn(shape_inference::SparseReduceShapeFn); |
448 | |
449 | REGISTER_OP("SparseReduceMaxSparse" ) |
450 | .Input("input_indices: int64" ) |
451 | .Input("input_values: T" ) |
452 | .Input("input_shape: int64" ) |
453 | .Input("reduction_axes: int32" ) |
454 | .Attr("keep_dims: bool = False" ) |
455 | .Output("output_indices: int64" ) |
456 | .Output("output_values: T" ) |
457 | .Output("output_shape: int64" ) |
458 | .Attr("T: realnumbertype" ) |
459 | .SetShapeFn(shape_inference::UnknownShape); |
460 | |
461 | REGISTER_OP("SparseReduceSum" ) |
462 | .Input("input_indices: int64" ) |
463 | .Input("input_values: T" ) |
464 | .Input("input_shape: int64" ) |
465 | .Input("reduction_axes: int32" ) |
466 | .Attr("keep_dims: bool = False" ) |
467 | .Output("output: T" ) |
468 | .Attr("T: numbertype" ) |
469 | .SetShapeFn(shape_inference::SparseReduceShapeFn); |
470 | |
471 | REGISTER_OP("SparseReduceSumSparse" ) |
472 | .Input("input_indices: int64" ) |
473 | .Input("input_values: T" ) |
474 | .Input("input_shape: int64" ) |
475 | .Input("reduction_axes: int32" ) |
476 | .Attr("keep_dims: bool = False" ) |
477 | .Output("output_indices: int64" ) |
478 | .Output("output_values: T" ) |
479 | .Output("output_shape: int64" ) |
480 | .Attr("T: numbertype" ) |
481 | .SetShapeFn(shape_inference::UnknownShape); |
482 | |
483 | #define SPARSE_DENSE_CWISE_SIGNATURE() \ |
484 | Input("sp_indices: int64") \ |
485 | .Input("sp_values: T") \ |
486 | .Input("sp_shape: int64") \ |
487 | .Input("dense: T") \ |
488 | .Output("output: T") \ |
489 | .Attr("T: numbertype") \ |
490 | .SetShapeFn([](InferenceContext* c) { \ |
491 | ShapeHandle input; \ |
492 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input)); \ |
493 | c->set_output(0, c->Vector(c->Dim(input, 0))); \ |
494 | return OkStatus(); \ |
495 | }) |
496 | |
497 | REGISTER_OP("SparseDenseCwiseMul" ).SPARSE_DENSE_CWISE_SIGNATURE(); |
498 | |
499 | REGISTER_OP("SparseDenseCwiseDiv" ).SPARSE_DENSE_CWISE_SIGNATURE(); |
500 | |
501 | REGISTER_OP("SparseDenseCwiseAdd" ).SPARSE_DENSE_CWISE_SIGNATURE(); |
502 | |
503 | #undef SPARSE_DENSE_CWISE_SIGNATURE |
504 | |
505 | REGISTER_OP("SparseSoftmax" ) |
506 | .Input("sp_indices: int64" ) |
507 | .Input("sp_values: T" ) |
508 | .Input("sp_shape: int64" ) |
509 | .Output("output: T" ) |
510 | .Attr("T: {half, float, double}" ) |
511 | .SetShapeFn([](InferenceContext* c) { |
512 | ShapeHandle unused; |
513 | ShapeHandle values; |
514 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); // sp_indices |
515 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &values)); // sp_values |
516 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
517 | c->set_output(0, values); |
518 | return OkStatus(); |
519 | }); |
520 | |
521 | REGISTER_OP("SparseSparseMaximum" ) |
522 | .Input("a_indices: int64" ) |
523 | .Input("a_values: T" ) |
524 | .Input("a_shape: int64" ) |
525 | .Input("b_indices: int64" ) |
526 | .Input("b_values: T" ) |
527 | .Input("b_shape: int64" ) |
528 | .Output("output_indices: int64" ) |
529 | .Output("output_values: T" ) |
530 | .Attr("T: realnumbertype" ) |
531 | .SetShapeFn(SparseSparseMinOrMaxShapeFn); |
532 | |
533 | REGISTER_OP("SparseSparseMinimum" ) |
534 | .Input("a_indices: int64" ) |
535 | .Input("a_values: T" ) |
536 | .Input("a_shape: int64" ) |
537 | .Input("b_indices: int64" ) |
538 | .Input("b_values: T" ) |
539 | .Input("b_shape: int64" ) |
540 | .Output("output_indices: int64" ) |
541 | .Output("output_values: T" ) |
542 | .Attr("T: numbertype" ) |
543 | .SetShapeFn(SparseSparseMinOrMaxShapeFn); |
544 | |
545 | REGISTER_OP("AddSparseToTensorsMap" ) |
546 | .Input("sparse_indices: int64" ) |
547 | .Input("sparse_values: T" ) |
548 | .Input("sparse_shape: int64" ) |
549 | .Output("sparse_handle: int64" ) |
550 | .Attr("T: type" ) |
551 | .Attr("container: string = ''" ) |
552 | .Attr("shared_name: string = ''" ) |
553 | .SetIsStateful() |
554 | .SetShapeFn([](InferenceContext* c) { |
555 | ShapeHandle unused; |
556 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); |
557 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
558 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
559 | c->set_output(0, c->Scalar()); |
560 | return OkStatus(); |
561 | }); |
562 | |
563 | REGISTER_OP("AddManySparseToTensorsMap" ) |
564 | .Input("sparse_indices: int64" ) |
565 | .Input("sparse_values: T" ) |
566 | .Input("sparse_shape: int64" ) |
567 | .Output("sparse_handles: int64" ) |
568 | .Attr("T: type" ) |
569 | .Attr("container: string = ''" ) |
570 | .Attr("shared_name: string = ''" ) |
571 | .SetIsStateful() |
572 | .SetShapeFn([](InferenceContext* c) { |
573 | ShapeHandle unused; |
574 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused)); |
575 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
576 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused)); |
577 | c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
578 | return OkStatus(); |
579 | }); |
580 | |
581 | REGISTER_OP("TakeManySparseFromTensorsMap" ) |
582 | .Input("sparse_handles: int64" ) |
583 | .Output("sparse_indices: int64" ) |
584 | .Output("sparse_values: dtype" ) |
585 | .Output("sparse_shape: int64" ) |
586 | .Attr("dtype: type" ) |
587 | .Attr("container: string = ''" ) |
588 | .Attr("shared_name: string = ''" ) |
589 | .SetIsStateful() |
590 | .SetShapeFn([](InferenceContext* c) { |
591 | // serialized sparse is [?,1] matrix. |
592 | ShapeHandle sparse_handles; |
593 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &sparse_handles)); |
594 | |
595 | c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, |
596 | InferenceContext::kUnknownDim)); |
597 | c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); |
598 | c->set_output(2, c->Vector(InferenceContext::kUnknownDim)); |
599 | return OkStatus(); |
600 | }); |
601 | |
602 | REGISTER_OP("SparseFillEmptyRows" ) |
603 | .Input("indices: int64" ) |
604 | .Input("values: T" ) |
605 | .Input("dense_shape: int64" ) |
606 | .Input("default_value: T" ) |
607 | .Output("output_indices: int64" ) |
608 | .Output("output_values: T" ) |
609 | .Output("empty_row_indicator: bool" ) |
610 | .Output("reverse_index_map: int64" ) |
611 | .Attr("T: type" ) |
612 | .SetShapeFn([](InferenceContext* c) { |
613 | ShapeHandle input_indices = c->input(0); |
614 | TF_RETURN_IF_ERROR(c->WithRank(input_indices, 2, &input_indices)); |
615 | ShapeHandle input_values = c->input(1); |
616 | TF_RETURN_IF_ERROR(c->WithRank(input_values, 1, &input_values)); |
617 | ShapeHandle input_shape = c->input(2); |
618 | TF_RETURN_IF_ERROR(c->WithRank(input_shape, 1, &input_shape)); |
619 | ShapeHandle default_value = c->input(3); |
620 | TF_RETURN_IF_ERROR(c->WithRank(default_value, 0, &default_value)); |
621 | DimensionHandle N = c->Dim(input_indices, 0); |
622 | TF_RETURN_IF_ERROR(c->Merge(N, c->Dim(input_values, 0), &N)); |
623 | DimensionHandle unused_dim; |
624 | TF_RETURN_IF_ERROR(c->Merge(c->Dim(input_indices, 1), |
625 | c->Dim(input_shape, 0), &unused_dim)); |
626 | if (c->Value(c->NumElements(input_shape)) == 0) |
627 | return errors::InvalidArgument("dense_shape must not be empty" ); |
628 | ShapeHandle output_indices = |
629 | c->Matrix(InferenceContext::kUnknownDim, c->NumElements(input_shape)); |
630 | ShapeHandle output_values = c->Vector(InferenceContext::kUnknownDim); |
631 | ShapeHandle constant_input_shape; |
632 | TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &constant_input_shape)); |
633 | ShapeHandle empty_row_indicator = |
634 | c->Vector(c->Dim(constant_input_shape, 0)); |
635 | ShapeHandle reverse_index_map = c->Vector(N); |
636 | c->set_output(0, output_indices); |
637 | c->set_output(1, output_values); |
638 | c->set_output(2, empty_row_indicator); |
639 | c->set_output(3, reverse_index_map); |
640 | return OkStatus(); |
641 | }); |
642 | |
643 | REGISTER_OP("SparseFillEmptyRowsGrad" ) |
644 | .Input("reverse_index_map: int64" ) |
645 | .Input("grad_values: T" ) |
646 | .Output("d_values: T" ) |
647 | .Output("d_default_value: T" ) |
648 | .Attr("T: type" ) |
649 | .SetShapeFn([](InferenceContext* c) { |
650 | ShapeHandle reverse_index_map = c->input(0); |
651 | TF_RETURN_IF_ERROR(c->WithRank(reverse_index_map, 1, &reverse_index_map)); |
652 | ShapeHandle grad_values = c->input(1); |
653 | TF_RETURN_IF_ERROR(c->WithRank(grad_values, 1, &grad_values)); |
654 | c->set_output(0, reverse_index_map); |
655 | c->set_output(1, c->Scalar()); |
656 | return OkStatus(); |
657 | }); |
658 | |
659 | } // namespace tensorflow |
660 | |