/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19
20namespace tensorflow {
21
22using shape_inference::DimensionHandle;
23using shape_inference::InferenceContext;
24using shape_inference::ShapeHandle;
25
26REGISTER_OP("SetSize")
27 .Input("set_indices: int64")
28 .Input("set_values: T")
29 .Input("set_shape: int64")
30 .Attr("validate_indices: bool = true")
31 .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
32 .Output("size: int32")
33 .SetShapeFn(shape_inference::UnknownShape);
34
35REGISTER_OP("DenseToDenseSetOperation")
36 .Input("set1: T")
37 .Input("set2: T")
38 .Attr("set_operation: string")
39 .Attr("validate_indices: bool = true")
40 .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
41 .Output("result_indices: int64")
42 .Output("result_values: T")
43 .Output("result_shape: int64")
44 .SetShapeFn([](InferenceContext* c) {
45 if (c->num_inputs() != 2) {
46 return errors::InvalidArgument("len(inputs) != 2.");
47 }
48 // The following should stay in sync with `ComputeDenseToDense` shape
49 // assertions in kernels/set_kernels.cc.
50 // Dimension n contains the set values to be compared, so ranks must be
51 // >= 2, and the first n-1 dimensions of inputs and output must be
52 // compatible.
53 DimensionHandle output_rank;
54 ShapeHandle input0_shape = c->input(0);
55 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape));
56 if (c->RankKnown(input0_shape)) {
57 const int32_t input0_rank = c->Rank(input0_shape);
58 ShapeHandle input1_shape = c->input(1);
59 TF_RETURN_IF_ERROR(
60 c->WithRank(input1_shape, input0_rank, &input1_shape));
61 if (c->RankKnown(input1_shape)) {
62 // If both ranks are specified, the first n-1 dims must be compatible.
63 const int32_t rank = c->Rank(input1_shape);
64 ShapeHandle group0_shape;
65 TF_RETURN_IF_ERROR(
66 c->Subshape(input0_shape, 0, rank - 1, &group0_shape));
67 ShapeHandle group1_shape;
68 TF_RETURN_IF_ERROR(
69 c->Subshape(input1_shape, 0, rank - 1, &group1_shape));
70 ShapeHandle unused_shape;
71 TF_RETURN_IF_ERROR(
72 c->Merge(group0_shape, group1_shape, &unused_shape));
73 }
74 output_rank = c->MakeDim(input0_rank);
75 } else {
76 ShapeHandle input1_shape = c->input(1);
77 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input1_shape, 2, &input1_shape));
78 if (c->RankKnown(input1_shape)) {
79 output_rank = c->MakeDim(c->Rank(input1_shape));
80 } else {
81 output_rank = c->UnknownDim();
82 }
83 }
84
85 c->set_output(0, c->Matrix(c->UnknownDim(), output_rank));
86 c->set_output(1, c->Vector(c->UnknownDim()));
87 c->set_output(2, c->Vector(output_rank));
88 return OkStatus();
89 });
90
91REGISTER_OP("DenseToSparseSetOperation")
92 .Input("set1: T")
93 .Input("set2_indices: int64")
94 .Input("set2_values: T")
95 .Input("set2_shape: int64")
96 .Attr("set_operation: string")
97 .Attr("validate_indices: bool = true")
98 .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
99 .Output("result_indices: int64")
100 .Output("result_values: T")
101 .Output("result_shape: int64")
102 .SetShapeFn([](InferenceContext* c) {
103 if (c->num_inputs() != 4) {
104 return errors::InvalidArgument("len(inputs) != 4.");
105 }
106 // The following should stay in sync with `ComputeDenseToSparse` shape
107 // assertions in kernels/set_kernels.cc.
108 // Ranks must be compatible, and be >= 2.
109 ShapeHandle input1_shape_shape = c->input(3);
110 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
111 c, c->input(1), c->input(2), input1_shape_shape));
112
113 DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0);
114
115 DimensionHandle output_rank_dim;
116 ShapeHandle input0_shape = c->input(0);
117 TF_RETURN_IF_ERROR(c->WithRankAtLeast(input0_shape, 2, &input0_shape));
118 if (c->RankKnown(input0_shape)) {
119 const int32_t input0_rank = c->Rank(input0_shape);
120 TF_RETURN_IF_ERROR(
121 c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim));
122 output_rank_dim = c->MakeDim(input0_rank);
123 } else if (c->ValueKnown(input1_rank_dim)) {
124 output_rank_dim = input1_rank_dim;
125 } else {
126 output_rank_dim = c->UnknownDim();
127 }
128
129 c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim));
130 c->set_output(1, c->Vector(c->UnknownDim()));
131 c->set_output(2, c->Vector(output_rank_dim));
132 return OkStatus();
133 });
134
135REGISTER_OP("SparseToSparseSetOperation")
136 .Input("set1_indices: int64")
137 .Input("set1_values: T")
138 .Input("set1_shape: int64")
139 .Input("set2_indices: int64")
140 .Input("set2_values: T")
141 .Input("set2_shape: int64")
142 .Attr("set_operation: string")
143 .Attr("validate_indices: bool = true")
144 .Attr("T: {int8, int16, int32, int64, uint8, uint16, string}")
145 .Output("result_indices: int64")
146 .Output("result_values: T")
147 .Output("result_shape: int64")
148 .SetShapeFn([](InferenceContext* c) {
149 if (c->num_inputs() != 6) {
150 return errors::InvalidArgument("len(inputs) != 6.");
151 }
152 // The following should stay in sync with `ComputeSparseToSparse` shape
153 // assertions in kernels/set_kernels.cc.
154 // Ranks must be compatible, and be >= 2.
155 ShapeHandle input0_shape_shape = c->input(2);
156 ShapeHandle input1_shape_shape = c->input(5);
157 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
158 c, c->input(0), c->input(1), input0_shape_shape));
159 TF_RETURN_IF_ERROR(shape_inference::ValidateSparseTensor(
160 c, c->input(3), c->input(4), input1_shape_shape));
161
162 DimensionHandle input0_rank_dim = c->Dim(input0_shape_shape, 0);
163 DimensionHandle input1_rank_dim = c->Dim(input1_shape_shape, 0);
164 DimensionHandle output_rank_dim;
165 if (c->ValueKnown(input0_rank_dim)) {
166 const int64_t input0_rank = c->Value(input0_rank_dim);
167 if (input0_rank < 2) {
168 return errors::InvalidArgument("Input 0, expected rank >= 2, got ",
169 input0_rank, ".");
170 }
171 TF_RETURN_IF_ERROR(
172 c->WithValue(input1_rank_dim, input0_rank, &input1_rank_dim));
173 output_rank_dim = input0_rank_dim;
174 } else if (c->ValueKnown(input1_rank_dim)) {
175 const int64_t input1_rank = c->Value(input1_rank_dim);
176 if (input1_rank < 2) {
177 return errors::InvalidArgument("Input 1, expected rank >= 2, got ",
178 input1_rank, ".");
179 }
180 output_rank_dim = input1_rank_dim;
181 } else {
182 output_rank_dim = c->UnknownDim();
183 }
184
185 c->set_output(0, c->Matrix(c->UnknownDim(), output_rank_dim));
186 c->set_output(1, c->Vector(c->UnknownDim()));
187 c->set_output(2, c->Vector(output_rank_dim));
188 return OkStatus();
189 });
190
191} // namespace tensorflow
192