1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | using shape_inference::InferenceContext; |
22 | using shape_inference::ShapeHandle; |
23 | |
// Registers the "VariableV2" op: no inputs and a single reference output
// `ref` whose element type is the `dtype` attr. The op is marked stateful.
// `container` and `shared_name` both default to the empty string.
// Shape inference is delegated to shape_inference::ExplicitShape
// (presumably derives the output shape from the `shape` attr -- confirm in
// common_shape_fns).
24 | REGISTER_OP("VariableV2" ) |
25 | .Output("ref: Ref(dtype)" ) |
26 | .Attr("shape: shape" ) |
27 | .Attr("dtype: type" ) |
28 | .Attr("container: string = ''" ) |
29 | .Attr("shared_name: string = ''" ) |
30 | .SetIsStateful() |
31 | .SetShapeFn(shape_inference::ExplicitShape); |
32 | |
33 | REGISTER_OP("Variable" ) |
34 | .Output("ref: Ref(dtype)" ) |
35 | .Attr("shape: shape" ) |
36 | .Attr("dtype: type" ) |
37 | .Attr("container: string = ''" ) |
38 | .Attr("shared_name: string = ''" ) |
39 | .SetIsStateful() |
40 | .SetShapeFn([](InferenceContext* c) { |
41 | PartialTensorShape shape; |
42 | TF_RETURN_IF_ERROR(c->GetAttr("shape" , &shape)); |
43 | |
44 | // Variable has legacy behavior where we cannot tell the difference |
45 | // between a scalar shape attribute and 'unknown shape'. So if the shape |
46 | // is a scalar, we return an unknown shape. |
47 | if (shape.dims() <= 0) { |
48 | return shape_inference::UnknownShape(c); |
49 | } |
50 | |
51 | ShapeHandle out; |
52 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out)); |
53 | c->set_output(0, out); |
54 | return OkStatus(); |
55 | }); |
56 | |
// Registers the "IsVariableInitialized" op: takes a reference input `ref` of
// element type `dtype` and produces a bool output `is_initialized`.
// SetAllowsUninitializedInput() is set, so the reference input may be
// uninitialized. Shape inference uses shape_inference::ScalarShape.
57 | REGISTER_OP("IsVariableInitialized" ) |
58 | .Input("ref: Ref(dtype)" ) |
59 | .Output("is_initialized: bool" ) |
60 | .Attr("dtype: type" ) |
61 | .SetAllowsUninitializedInput() |
62 | .SetShapeFn(shape_inference::ScalarShape); |
63 | |
// Registers the "TemporaryVariable" op: a stateful op with no inputs and one
// reference output `ref` of element type `dtype`. `var_name` defaults to the
// empty string. Shape inference is delegated to
// shape_inference::ExplicitShape (presumably driven by the `shape` attr --
// confirm in common_shape_fns).
64 | REGISTER_OP("TemporaryVariable" ) |
65 | .Output("ref: Ref(dtype)" ) |
66 | .Attr("shape: shape" ) |
67 | .Attr("dtype: type" ) |
68 | .Attr("var_name: string = ''" ) |
69 | .SetIsStateful() |
70 | .SetShapeFn(shape_inference::ExplicitShape); |
71 | |
// Registers the "DestroyTemporaryVariable" op: consumes a reference input
// `ref` and yields a non-reference output `value` of the same element type T.
// `var_name` has no default here, so it must be supplied explicitly.
// Shape inference uses shape_inference::UnchangedShape.
72 | REGISTER_OP("DestroyTemporaryVariable" ) |
73 | .Input("ref: Ref(T)" ) |
74 | .Output("value: T" ) |
75 | .Attr("T: type" ) |
76 | .Attr("var_name: string" ) |
77 | .SetShapeFn(shape_inference::UnchangedShape); |
78 | |
79 | REGISTER_OP("Assign" ) |
80 | .Input("ref: Ref(T)" ) |
81 | .Input("value: T" ) |
82 | .Output("output_ref: Ref(T)" ) |
83 | .Attr("T: type" ) |
84 | .Attr("validate_shape: bool = true" ) |
85 | .Attr("use_locking: bool = true" ) |
86 | .SetAllowsUninitializedInput() |
87 | .SetShapeFn([](InferenceContext* c) { |
88 | bool validate_shape; |
89 | TF_RETURN_IF_ERROR(c->GetAttr("validate_shape" , &validate_shape)); |
90 | if (validate_shape) { |
91 | return shape_inference::MergeBothInputsShapeFn(c); |
92 | } |
93 | |
94 | c->set_output(0, c->input(1)); |
95 | return OkStatus(); |
96 | }); |
97 | |
// Registers the "AssignAdd" op: reference input `ref`, tensor input `value`,
// reference output `output_ref`. T is restricted to `numbertype` and
// `use_locking` defaults to false. Shape inference uses
// shape_inference::MergeBothInputsShapeFn.
98 | REGISTER_OP("AssignAdd" ) |
99 | .Input("ref: Ref(T)" ) |
100 | .Input("value: T" ) |
101 | .Output("output_ref: Ref(T)" ) |
102 | .Attr("T: numbertype" ) |
103 | .Attr("use_locking: bool = false" ) |
104 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn); |
105 | |
// Registers the "AssignSub" op with the same signature, attrs and shape
// function as "AssignAdd".
106 | REGISTER_OP("AssignSub" ) |
107 | .Input("ref: Ref(T)" ) |
108 | .Input("value: T" ) |
109 | .Output("output_ref: Ref(T)" ) |
110 | .Attr("T: numbertype" ) |
111 | .Attr("use_locking: bool = false" ) |
112 | .SetShapeFn(shape_inference::MergeBothInputsShapeFn); |
113 | |
114 | namespace { |
115 | |
116 | Status ScatterUpdateShape(InferenceContext* c) { |
117 | ShapeHandle var_shape = c->input(0); |
118 | ShapeHandle indices_shape = c->input(1); |
119 | |
120 | ShapeHandle unused_updates_shape; |
121 | ShapeHandle concat; |
122 | ShapeHandle var_subshape; |
123 | TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); |
124 | TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); |
125 | TF_RETURN_IF_ERROR( |
126 | InferenceContext::Rank(c->input(2)) == 0 |
127 | ? OkStatus() |
128 | : c->Merge(c->input(2), concat, &unused_updates_shape)); |
129 | |
130 | c->set_output(0, var_shape); |
131 | return OkStatus(); |
132 | } |
133 | |
134 | Status ScatterNdUpdateShape(InferenceContext* c) { |
135 | ShapeHandle input_shape = c->input(0); |
136 | if (c->input_handle_shapes_and_types(0) != nullptr) { |
137 | const auto& shape_and_type = *(c->input_handle_shapes_and_types(0)); |
138 | if (!shape_and_type.empty()) { |
139 | input_shape = shape_and_type[0].shape; |
140 | } |
141 | } |
142 | ShapeHandle indices_shape; |
143 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape)); |
144 | ShapeHandle updates_shape; |
145 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape)); |
146 | return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape, |
147 | input_shape); |
148 | } |
149 | |
150 | } // namespace |
151 | |
// The seven ops below share one signature -- reference input `ref`, index
// input `indices` (int32 or int64), tensor input `updates`, reference output
// `output_ref` -- and all use ScatterUpdateShape for shape inference. They
// differ only in the allowed element types T and the `use_locking` default.

// "ScatterUpdate": T may be any `type`; `use_locking` defaults to true.
152 | REGISTER_OP("ScatterUpdate" ) |
153 | .Input("ref: Ref(T)" ) |
154 | .Input("indices: Tindices" ) |
155 | .Input("updates: T" ) |
156 | .Output("output_ref: Ref(T)" ) |
157 | .Attr("T: type" ) |
158 | .Attr("Tindices: {int32, int64}" ) |
159 | .Attr("use_locking: bool = true" ) |
160 | .SetShapeFn(ScatterUpdateShape); |
161 | |
// "ScatterAdd": T restricted to `numbertype`; `use_locking` defaults to false.
162 | REGISTER_OP("ScatterAdd" ) |
163 | .Input("ref: Ref(T)" ) |
164 | .Input("indices: Tindices" ) |
165 | .Input("updates: T" ) |
166 | .Output("output_ref: Ref(T)" ) |
167 | .Attr("T: numbertype" ) |
168 | .Attr("Tindices: {int32, int64}" ) |
169 | .Attr("use_locking: bool = false" ) |
170 | .SetShapeFn(ScatterUpdateShape); |
171 | |
// "ScatterSub": same attrs as "ScatterAdd".
172 | REGISTER_OP("ScatterSub" ) |
173 | .Input("ref: Ref(T)" ) |
174 | .Input("indices: Tindices" ) |
175 | .Input("updates: T" ) |
176 | .Output("output_ref: Ref(T)" ) |
177 | .Attr("T: numbertype" ) |
178 | .Attr("Tindices: {int32, int64}" ) |
179 | .Attr("use_locking: bool = false" ) |
180 | .SetShapeFn(ScatterUpdateShape); |
181 | |
// "ScatterMul": same attrs as "ScatterAdd".
182 | REGISTER_OP("ScatterMul" ) |
183 | .Input("ref: Ref(T)" ) |
184 | .Input("indices: Tindices" ) |
185 | .Input("updates: T" ) |
186 | .Output("output_ref: Ref(T)" ) |
187 | .Attr("T: numbertype" ) |
188 | .Attr("Tindices: {int32, int64}" ) |
189 | .Attr("use_locking: bool = false" ) |
190 | .SetShapeFn(ScatterUpdateShape); |
191 | |
// "ScatterDiv": same attrs as "ScatterAdd".
192 | REGISTER_OP("ScatterDiv" ) |
193 | .Input("ref: Ref(T)" ) |
194 | .Input("indices: Tindices" ) |
195 | .Input("updates: T" ) |
196 | .Output("output_ref: Ref(T)" ) |
197 | .Attr("T: numbertype" ) |
198 | .Attr("Tindices: {int32, int64}" ) |
199 | .Attr("use_locking: bool = false" ) |
200 | .SetShapeFn(ScatterUpdateShape); |
201 | |
// "ScatterMin": T restricted to an explicit list of types (half, bfloat16,
// float, double, int32, int64); `use_locking` defaults to false.
202 | REGISTER_OP("ScatterMin" ) |
203 | .Input("ref: Ref(T)" ) |
204 | .Input("indices: Tindices" ) |
205 | .Input("updates: T" ) |
206 | .Output("output_ref: Ref(T)" ) |
207 | .Attr("T: {half, bfloat16, float, double, int32, int64}" ) |
208 | .Attr("Tindices: {int32, int64}" ) |
209 | .Attr("use_locking: bool = false" ) |
210 | .SetShapeFn(ScatterUpdateShape); |
211 | |
// "ScatterMax": same attrs as "ScatterMin".
212 | REGISTER_OP("ScatterMax" ) |
213 | .Input("ref: Ref(T)" ) |
214 | .Input("indices: Tindices" ) |
215 | .Input("updates: T" ) |
216 | .Output("output_ref: Ref(T)" ) |
217 | .Attr("T: {half, bfloat16, float, double, int32, int64}" ) |
218 | .Attr("Tindices: {int32, int64}" ) |
219 | .Attr("use_locking: bool = false" ) |
220 | .SetShapeFn(ScatterUpdateShape); |
221 | |
// Registers the "ScatterNdUpdate" op: reference input `ref`, index input
// `indices` (int32 or int64), tensor input `updates`, reference output
// `output_ref`. T may be any `type`; `use_locking` defaults to true.
// Shape inference uses ScatterNdUpdateShape.
222 | REGISTER_OP("ScatterNdUpdate" ) |
223 | .Input("ref: Ref(T)" ) |
224 | .Input("indices: Tindices" ) |
225 | .Input("updates: T" ) |
226 | .Output("output_ref: Ref(T)" ) |
227 | .Attr("T: type" ) |
228 | .Attr("Tindices: {int32, int64}" ) |
229 | .Attr("use_locking: bool = true" ) |
230 | .SetShapeFn(ScatterNdUpdateShape); |
231 | |
// The five ResourceScatterNd* ops below share one signature: a `resource`
// input `ref`, index input `indices` (int32 or int64) and tensor input
// `updates`. None of them declares an output. All use ScatterNdUpdateShape
// for shape inference, accept any `type` for T, and default `use_locking`
// to true.
// NOTE(review): the Ref-based ScatterNdAdd/Sub/Max/Min registrations in this
// file restrict T to `numbertype` and default `use_locking` to false; confirm
// the looser constraints here are intentional.

// "ResourceScatterNdUpdate".
232 | REGISTER_OP("ResourceScatterNdUpdate" ) |
233 | .Input("ref: resource" ) |
234 | .Input("indices: Tindices" ) |
235 | .Input("updates: T" ) |
236 | .Attr("T: type" ) |
237 | .Attr("Tindices: {int32, int64}" ) |
238 | .Attr("use_locking: bool = true" ) |
239 | .SetShapeFn(ScatterNdUpdateShape); |
240 | |
// "ResourceScatterNdAdd".
241 | REGISTER_OP("ResourceScatterNdAdd" ) |
242 | .Input("ref: resource" ) |
243 | .Input("indices: Tindices" ) |
244 | .Input("updates: T" ) |
245 | .Attr("T: type" ) |
246 | .Attr("Tindices: {int32, int64}" ) |
247 | .Attr("use_locking: bool = true" ) |
248 | .SetShapeFn(ScatterNdUpdateShape); |
249 | |
// "ResourceScatterNdSub".
250 | REGISTER_OP("ResourceScatterNdSub" ) |
251 | .Input("ref: resource" ) |
252 | .Input("indices: Tindices" ) |
253 | .Input("updates: T" ) |
254 | .Attr("T: type" ) |
255 | .Attr("Tindices: {int32, int64}" ) |
256 | .Attr("use_locking: bool = true" ) |
257 | .SetShapeFn(ScatterNdUpdateShape); |
258 | |
// "ResourceScatterNdMin".
259 | REGISTER_OP("ResourceScatterNdMin" ) |
260 | .Input("ref: resource" ) |
261 | .Input("indices: Tindices" ) |
262 | .Input("updates: T" ) |
263 | .Attr("T: type" ) |
264 | .Attr("Tindices: {int32, int64}" ) |
265 | .Attr("use_locking: bool = true" ) |
266 | .SetShapeFn(ScatterNdUpdateShape); |
267 | |
// "ResourceScatterNdMax".
268 | REGISTER_OP("ResourceScatterNdMax" ) |
269 | .Input("ref: resource" ) |
270 | .Input("indices: Tindices" ) |
271 | .Input("updates: T" ) |
272 | .Attr("T: type" ) |
273 | .Attr("Tindices: {int32, int64}" ) |
274 | .Attr("use_locking: bool = true" ) |
275 | .SetShapeFn(ScatterNdUpdateShape); |
276 | |
// The four Ref-based ScatterNd ops below share one signature -- reference
// input `ref`, index input `indices` (int32 or int64), tensor input
// `updates`, reference output `output_ref` -- with T restricted to
// `numbertype` and `use_locking` defaulting to false. All use
// ScatterNdUpdateShape for shape inference.

// "ScatterNdAdd".
277 | REGISTER_OP("ScatterNdAdd" ) |
278 | .Input("ref: Ref(T)" ) |
279 | .Input("indices: Tindices" ) |
280 | .Input("updates: T" ) |
281 | .Output("output_ref: Ref(T)" ) |
282 | .Attr("T: numbertype" ) |
283 | .Attr("Tindices: {int32, int64}" ) |
284 | .Attr("use_locking: bool = false" ) |
285 | .SetShapeFn(ScatterNdUpdateShape); |
286 | |
// "ScatterNdSub".
287 | REGISTER_OP("ScatterNdSub" ) |
288 | .Input("ref: Ref(T)" ) |
289 | .Input("indices: Tindices" ) |
290 | .Input("updates: T" ) |
291 | .Output("output_ref: Ref(T)" ) |
292 | .Attr("T: numbertype" ) |
293 | .Attr("Tindices: {int32, int64}" ) |
294 | .Attr("use_locking: bool = false" ) |
295 | .SetShapeFn(ScatterNdUpdateShape); |
296 | |
// "ScatterNdMax".
297 | REGISTER_OP("ScatterNdMax" ) |
298 | .Input("ref: Ref(T)" ) |
299 | .Input("indices: Tindices" ) |
300 | .Input("updates: T" ) |
301 | .Output("output_ref: Ref(T)" ) |
302 | .Attr("T: numbertype" ) |
303 | .Attr("Tindices: {int32, int64}" ) |
304 | .Attr("use_locking: bool = false" ) |
305 | .SetShapeFn(ScatterNdUpdateShape); |
306 | |
// "ScatterNdMin".
307 | REGISTER_OP("ScatterNdMin" ) |
308 | .Input("ref: Ref(T)" ) |
309 | .Input("indices: Tindices" ) |
310 | .Input("updates: T" ) |
311 | .Output("output_ref: Ref(T)" ) |
312 | .Attr("T: numbertype" ) |
313 | .Attr("Tindices: {int32, int64}" ) |
314 | .Attr("use_locking: bool = false" ) |
315 | .SetShapeFn(ScatterNdUpdateShape); |
316 | |
317 | REGISTER_OP("CountUpTo" ) |
318 | .Input("ref: Ref(T)" ) |
319 | .Output("output: T" ) |
320 | .Attr("limit: int" ) |
321 | .Attr("T: {int32, int64}" ) |
322 | .SetShapeFn([](InferenceContext* c) { |
323 | ShapeHandle output; |
324 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &output)); |
325 | c->set_output(0, output); |
326 | return OkStatus(); |
327 | }); |
328 | |
329 | REGISTER_OP("ResourceCountUpTo" ) |
330 | .Input("resource: resource" ) |
331 | .Output("output: T" ) |
332 | .Attr("limit: int" ) |
333 | .Attr("T: {int32, int64}" ) |
334 | .SetShapeFn([](InferenceContext* c) { |
335 | auto* handle_data = c->input_handle_shapes_and_types(0); |
336 | if (handle_data == nullptr || handle_data->empty()) { |
337 | return errors::InvalidArgument("Handle has no shape/type information." ); |
338 | } |
339 | shape_inference::ShapeAndType shape_and_type = (*handle_data)[0]; |
340 | DataType value_dtype; |
341 | TF_RETURN_IF_ERROR(c->GetAttr("T" , &value_dtype)); |
342 | if (value_dtype != shape_and_type.dtype) { |
343 | return errors::InvalidArgument( |
344 | "Data types do not match: " , DataTypeString(value_dtype), " and " , |
345 | DataTypeString(shape_and_type.dtype)); |
346 | } |
347 | ShapeHandle output; |
348 | TF_RETURN_IF_ERROR(c->WithRank(shape_and_type.shape, 0, &output)); |
349 | c->set_output(0, output); |
350 | return OkStatus(); |
351 | }); |
352 | |
353 | } // namespace tensorflow |
354 | |