1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18
19namespace tensorflow {
20
21using shape_inference::InferenceContext;
22using shape_inference::ShapeHandle;
23
// Stateful op producing a reference of type `dtype`. The `shape` and `dtype`
// attrs parameterize the output; `container` and `shared_name` default to ''.
// The output shape is computed by shape_inference::ExplicitShape (presumably
// the `shape` attr taken verbatim — confirm). Unlike the legacy "Variable" op,
// no rank <= 0 -> unknown-shape remapping is applied here.
REGISTER_OP("VariableV2")
    .Output("ref: Ref(dtype)")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ExplicitShape);
32
33REGISTER_OP("Variable")
34 .Output("ref: Ref(dtype)")
35 .Attr("shape: shape")
36 .Attr("dtype: type")
37 .Attr("container: string = ''")
38 .Attr("shared_name: string = ''")
39 .SetIsStateful()
40 .SetShapeFn([](InferenceContext* c) {
41 PartialTensorShape shape;
42 TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape));
43
44 // Variable has legacy behavior where we cannot tell the difference
45 // between a scalar shape attribute and 'unknown shape'. So if the shape
46 // is a scalar, we return an unknown shape.
47 if (shape.dims() <= 0) {
48 return shape_inference::UnknownShape(c);
49 }
50
51 ShapeHandle out;
52 TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shape, &out));
53 c->set_output(0, out);
54 return OkStatus();
55 });
56
// Takes a ref of type `dtype` and emits a scalar bool (ScalarShape). The op is
// registered with SetAllowsUninitializedInput(), so its ref input may be
// uninitialized (presumably the very condition it reports on).
REGISTER_OP("IsVariableInitialized")
    .Input("ref: Ref(dtype)")
    .Output("is_initialized: bool")
    .Attr("dtype: type")
    .SetAllowsUninitializedInput()
    .SetShapeFn(shape_inference::ScalarShape);
63
// Stateful op producing a ref of type `dtype`, shaped via ExplicitShape
// (presumably from the `shape` attr). `var_name` defaults to ''. Unlike
// VariableV2 it has no `container`/`shared_name` attrs.
REGISTER_OP("TemporaryVariable")
    .Output("ref: Ref(dtype)")
    .Attr("shape: shape")
    .Attr("dtype: type")
    .Attr("var_name: string = ''")
    .SetIsStateful()
    .SetShapeFn(shape_inference::ExplicitShape);
71
// Consumes a ref of type T and emits a (non-ref) value of type T. The output
// shape comes from UnchangedShape (presumably input 0's shape). `var_name` has
// no default and must be supplied; presumably it identifies the
// TemporaryVariable being destroyed — verify against the kernel.
REGISTER_OP("DestroyTemporaryVariable")
    .Input("ref: Ref(T)")
    .Output("value: T")
    .Attr("T: type")
    .Attr("var_name: string")
    .SetShapeFn(shape_inference::UnchangedShape);
78
79REGISTER_OP("Assign")
80 .Input("ref: Ref(T)")
81 .Input("value: T")
82 .Output("output_ref: Ref(T)")
83 .Attr("T: type")
84 .Attr("validate_shape: bool = true")
85 .Attr("use_locking: bool = true")
86 .SetAllowsUninitializedInput()
87 .SetShapeFn([](InferenceContext* c) {
88 bool validate_shape;
89 TF_RETURN_IF_ERROR(c->GetAttr("validate_shape", &validate_shape));
90 if (validate_shape) {
91 return shape_inference::MergeBothInputsShapeFn(c);
92 }
93
94 c->set_output(0, c->input(1));
95 return OkStatus();
96 });
97
// Numeric update-in-place ops on a ref; each returns the ref. Shapes go through
// MergeBothInputsShapeFn (presumably requiring `ref` and `value` to be
// compatible). Unlike Assign, these do not allow an uninitialized ref input,
// and use_locking defaults to false.
REGISTER_OP("AssignAdd")
    .Input("ref: Ref(T)")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);

// Subtraction counterpart of AssignAdd; same signature and shape function.
REGISTER_OP("AssignSub")
    .Input("ref: Ref(T)")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("use_locking: bool = false")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
113
114namespace {
115
116Status ScatterUpdateShape(InferenceContext* c) {
117 ShapeHandle var_shape = c->input(0);
118 ShapeHandle indices_shape = c->input(1);
119
120 ShapeHandle unused_updates_shape;
121 ShapeHandle concat;
122 ShapeHandle var_subshape;
123 TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
124 TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
125 TF_RETURN_IF_ERROR(
126 InferenceContext::Rank(c->input(2)) == 0
127 ? OkStatus()
128 : c->Merge(c->input(2), concat, &unused_updates_shape));
129
130 c->set_output(0, var_shape);
131 return OkStatus();
132}
133
134Status ScatterNdUpdateShape(InferenceContext* c) {
135 ShapeHandle input_shape = c->input(0);
136 if (c->input_handle_shapes_and_types(0) != nullptr) {
137 const auto& shape_and_type = *(c->input_handle_shapes_and_types(0));
138 if (!shape_and_type.empty()) {
139 input_shape = shape_and_type[0].shape;
140 }
141 }
142 ShapeHandle indices_shape;
143 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
144 ShapeHandle updates_shape;
145 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(2), 1, &updates_shape));
146 return shape_inference::ScatterNdShapeHelper(c, indices_shape, updates_shape,
147 input_shape);
148}
149
150} // namespace
151
// Ref-based sparse scatter ops. All share ScatterUpdateShape: the output has
// the shape of `ref`, and non-scalar `updates` must match
// indices.shape + ref.shape[1:]. Tindices is int32 or int64.

// Accepts any T; use_locking defaults to true.
REGISTER_OP("ScatterUpdate")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(ScatterUpdateShape);

// Arithmetic variants below: numeric T only, use_locking defaults to false.
REGISTER_OP("ScatterAdd")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

REGISTER_OP("ScatterSub")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

REGISTER_OP("ScatterMul")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

REGISTER_OP("ScatterDiv")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

// Min/Max restrict T to an explicit list rather than numbertype.
REGISTER_OP("ScatterMin")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: {half, bfloat16, float, double, int32, int64}")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);

REGISTER_OP("ScatterMax")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: {half, bfloat16, float, double, int32, int64}")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterUpdateShape);
221
// N-d scatter ops sharing ScatterNdUpdateShape (indices and updates must have
// rank >= 1). All accept any T and default use_locking to true.

// Ref-based variant: returns the updated ref.
REGISTER_OP("ScatterNdUpdate")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(ScatterNdUpdateShape);

// Resource-handle variants: take a `resource` input and declare no outputs.
// Their variable shape is read from the handle's shape/type data when present.
REGISTER_OP("ResourceScatterNdUpdate")
    .Input("ref: resource")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(ScatterNdUpdateShape);

REGISTER_OP("ResourceScatterNdAdd")
    .Input("ref: resource")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(ScatterNdUpdateShape);

REGISTER_OP("ResourceScatterNdSub")
    .Input("ref: resource")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(ScatterNdUpdateShape);

REGISTER_OP("ResourceScatterNdMin")
    .Input("ref: resource")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(ScatterNdUpdateShape);

REGISTER_OP("ResourceScatterNdMax")
    .Input("ref: resource")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Attr("T: type")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = true")
    .SetShapeFn(ScatterNdUpdateShape);
276
// Ref-based N-d arithmetic scatter ops: numeric T only, use_locking defaults
// to false, and the shape function is shared with ScatterNdUpdate.
REGISTER_OP("ScatterNdAdd")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterNdUpdateShape);

REGISTER_OP("ScatterNdSub")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterNdUpdateShape);

REGISTER_OP("ScatterNdMax")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterNdUpdateShape);

REGISTER_OP("ScatterNdMin")
    .Input("ref: Ref(T)")
    .Input("indices: Tindices")
    .Input("updates: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: numbertype")
    .Attr("Tindices: {int32, int64}")
    .Attr("use_locking: bool = false")
    .SetShapeFn(ScatterNdUpdateShape);
316
317REGISTER_OP("CountUpTo")
318 .Input("ref: Ref(T)")
319 .Output("output: T")
320 .Attr("limit: int")
321 .Attr("T: {int32, int64}")
322 .SetShapeFn([](InferenceContext* c) {
323 ShapeHandle output;
324 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &output));
325 c->set_output(0, output);
326 return OkStatus();
327 });
328
329REGISTER_OP("ResourceCountUpTo")
330 .Input("resource: resource")
331 .Output("output: T")
332 .Attr("limit: int")
333 .Attr("T: {int32, int64}")
334 .SetShapeFn([](InferenceContext* c) {
335 auto* handle_data = c->input_handle_shapes_and_types(0);
336 if (handle_data == nullptr || handle_data->empty()) {
337 return errors::InvalidArgument("Handle has no shape/type information.");
338 }
339 shape_inference::ShapeAndType shape_and_type = (*handle_data)[0];
340 DataType value_dtype;
341 TF_RETURN_IF_ERROR(c->GetAttr("T", &value_dtype));
342 if (value_dtype != shape_and_type.dtype) {
343 return errors::InvalidArgument(
344 "Data types do not match: ", DataTypeString(value_dtype), " and ",
345 DataTypeString(shape_and_type.dtype));
346 }
347 ShapeHandle output;
348 TF_RETURN_IF_ERROR(c->WithRank(shape_and_type.shape, 0, &output));
349 c->set_output(0, output);
350 return OkStatus();
351 });
352
353} // namespace tensorflow
354