// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// ============================================================================

#include <algorithm>
#include <iterator>
#include <vector>

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/core/errors.h"

using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeAndType;
using ::tensorflow::shape_inference::ShapeHandle;

namespace tensorflow {

namespace {
Status ReadVariableShapeFn(InferenceContext* c) {
  // The user can add a "_shape" attribute to ReadVariableOp nodes. It is
  // useful for inferring shapes in a function, when no shape information
  // is passed about input resources. The user can annotate the graph using
  // the variable capture list of the function.
  // If the "_shape" attribute is found, it is used to set the output shape.
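  // For example, a ReadVariableOp node could carry an annotation like the
  // following (an illustrative NodeDef sketch, not a required format):
  //   node {
  //     name: "read"  op: "ReadVariableOp"
  //     attr { key: "_shape"
  //            value { shape { dim { size: 28 } dim { size: 28 } } } }
  //   }
  // which would let shape inference report [28, 28] for the read value.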
  PartialTensorShape p;
  Status annotation_found_status = c->GetAttr("_shape", &p);
  if (annotation_found_status.ok()) {
    ShapeHandle s;
    TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
    c->set_output(0, s);
  } else {
    std::vector<ShapeAndType> shape_and_type;
    TF_RETURN_IF_ERROR(
        shape_inference::ValidateVariableResourceHandle(c, &shape_and_type));
    c->set_output(0, shape_and_type[0].shape);
    if (shape_and_type[0].dtype == DT_VARIANT && shape_and_type.size() > 1) {
      std::vector<ShapeAndType> variant_shape_and_type;
      std::copy(shape_and_type.begin() + 1, shape_and_type.end(),
                std::back_inserter(variant_shape_and_type));
      c->set_output_handle_shapes_and_types(0, variant_shape_and_type);
    }
  }
  return OkStatus();
}

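// Shape function for _ReadVariablesOp: for each of the N resource inputs,
// propagates the shape recorded on the handle (unknown if none is recorded)
// and checks the handle's dtype against the declared value dtype.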
Status ReadVariablesShapeFn(InferenceContext* c) {
  int n;
  TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
  DataTypeVector value_dtypes;
  TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &value_dtypes));
  if (n != value_dtypes.size()) {
    return errors::InvalidArgument(
        "Mismatched number of arguments to ReadVariablesOp");
  }
  for (int i = 0; i < n; ++i) {
    ShapeAndType shape_and_type;
    auto* handle_data = c->input_handle_shapes_and_types(i);
    if (handle_data == nullptr || handle_data->empty()) {
      shape_and_type.shape = c->UnknownShape();
      shape_and_type.dtype = DT_INVALID;
    } else {
      shape_and_type = (*handle_data)[0];
      if (shape_and_type.dtype != value_dtypes[i]) {
        return errors::InvalidArgument(
            "Trying to read variable with wrong dtype. "
            "Expected ",
            DataTypeString(shape_and_type.dtype), " got ",
            DataTypeString(value_dtypes[i]));
      }
    }
    c->set_output(i, shape_and_type.shape);
  }
  return OkStatus();
}

}  // namespace

REGISTER_OP("VarHandleOp")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Attr("dtype: type")
    .Attr("shape: shape")
    .Attr("allowed_devices: list(string) = []")
    .Output("resource: resource")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->Scalar());
      DataType t;
      TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
      PartialTensorShape p;
      TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
      c->set_output_handle_shapes_and_types(0,
                                            std::vector<ShapeAndType>{{s, t}});

      return OkStatus();
    });

REGISTER_OP("_VarHandlesOp")
    .Attr("containers: list(string)")
    .Attr("shared_names: list(string)")
    .Attr("N: int >= 0")
    .Attr("dtypes: list(type)")
    .Attr("shapes: list(shape)")
    .Output("resources: N * resource")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      int n;
      TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
      DataTypeVector dtypes;
      TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
      std::vector<PartialTensorShape> shapes;
      TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
      if (dtypes.size() != n) {
        return errors::InvalidArgument("Mismatched number of dtypes (n=", n,
                                       ", num dtypes=", dtypes.size(), ")");
      }
      if (shapes.size() != n) {
        return errors::InvalidArgument("Mismatched number of shapes (n=", n,
                                       ", num shapes=", shapes.size(), ")");
      }
      for (int i = 0; i < n; ++i) {
        c->set_output(i, c->Scalar());
        ShapeHandle s;
        TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s));
        c->set_output_handle_shapes_and_types(
            i, std::vector<ShapeAndType>{{s, dtypes[i]}});
      }

      return OkStatus();
    });

REGISTER_OP("ReadVariableOp")
    .Input("resource: resource")
    .Output("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn(ReadVariableShapeFn);

REGISTER_OP("_ReadVariablesOp")
    .Attr("N: int >= 0")
    .Input("resources: N * resource")
    .Output("values: dtypes")
    .Attr("dtypes: list(type)")
    .SetShapeFn(ReadVariablesShapeFn);

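// Gradient for ReadVariableOp: reading a variable is the identity on its
// value, so the gradient with respect to the variable is just the incoming
// gradient dy, returned unchanged (no extra nodes are needed).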
Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
  // clang-format off
  *g = FunctionDefHelper::Define(
      // Arg defs
      {"x: resource", "dy: float"},
      // Ret val defs
      {"dy: float"},
      // Attr defs
      {},
      // Nodes
      {});
  // clang-format on
  return OkStatus();
}
REGISTER_OP_GRADIENT("ReadVariableOp", ReadGrad);

REGISTER_OP("DestroyResourceOp")
    .Input("resource: resource")
    .Attr("ignore_lookup_error: bool = true")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);

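// Shape function shared by the AssignVariableOp family: the shape of the
// value being written must be compatible with the shape recorded on the
// variable's handle, and for DT_VARIANT variables the nested shape-and-type
// metadata of the value must match what the handle carries.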
Status CreateAssignShapeFn(InferenceContext* c) {
  std::vector<ShapeAndType> handle_shape_and_type;
  TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle(
      c, &handle_shape_and_type));

  ShapeHandle value_shape = c->input(1);
  ShapeHandle unused;
  TF_RETURN_IF_ERROR(
      c->Merge(handle_shape_and_type[0].shape, value_shape, &unused));

  if (handle_shape_and_type[0].dtype == DT_VARIANT &&
      handle_shape_and_type.size() > 1 &&
      c->input_handle_shapes_and_types(1) != nullptr) {
    auto* value_handle_shape_and_type = c->input_handle_shapes_and_types(1);
    if (value_handle_shape_and_type->size() !=
        handle_shape_and_type.size() - 1) {
      return errors::InvalidArgument(
          "Incompatible handle variant shape_and_type size and input "
          "shape_and_type size: ",
          handle_shape_and_type.size() - 1, " vs. ",
          value_handle_shape_and_type->size());
    }
  }
  return OkStatus();
}

REGISTER_OP("AssignVariableOp")
    .Input("resource: resource")
    .Input("value: dtype")
    .Attr("dtype: type")
    .Attr("validate_shape: bool = false")
    .SetShapeFn(CreateAssignShapeFn);

REGISTER_OP("AssignAddVariableOp")
    .Input("resource: resource")
    .Input("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn(CreateAssignShapeFn);

REGISTER_OP("AssignSubVariableOp")
    .Input("resource: resource")
    .Input("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn(CreateAssignShapeFn);

REGISTER_OP("VarIsInitializedOp")
    .Input("resource: resource")
    .Output("is_initialized: bool")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);

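// Shape function for VariableShape: the output is a vector whose length is
// the variable's rank (e.g. a vector of length 3 for a variable of shape
// [2, 4, 8]), or a vector of unknown length if the rank is not known.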
Status VariableShapeShapeFn(InferenceContext* c) {
  auto* handle_data = c->input_handle_shapes_and_types(0);
  if (handle_data == nullptr || handle_data->empty()) {
    c->set_output(0, c->Vector(c->UnknownDim()));
    return OkStatus();
  }
  ShapeHandle var_shape = (*handle_data)[0].shape;
  int64_t rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
                                         : InferenceContext::kUnknownDim;
  c->set_output(0, c->Vector(rank));
  return OkStatus();
}

REGISTER_OP("VariableShape")
    .Input("input: resource")
    .Output("output: out_type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(VariableShapeShapeFn);

REGISTER_OP("ResourceGather")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Attr("batch_dims: int = 0")
    .Attr("validate_indices: bool = true")
    .Output("output: dtype")
    .Attr("dtype: type")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn([](InferenceContext* c) {
      std::vector<ShapeAndType> handle_shape_and_type;
      TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle(
          c, &handle_shape_and_type));

      ShapeHandle indices_shape = c->input(1);

      ShapeHandle unused;
      int32_t batch_dims;
      TF_RETURN_IF_ERROR(c->GetAttr("batch_dims", &batch_dims));
      if (batch_dims < 0)
        return errors::InvalidArgument("batch_dims is negative (", batch_dims,
                                       ")");

      TF_RETURN_IF_ERROR(c->WithRankAtLeast(handle_shape_and_type[0].shape,
                                            batch_dims + 1, &unused));

      TF_RETURN_IF_ERROR(
          c->WithRankAtLeast(indices_shape, batch_dims, &unused));

      ShapeHandle params_subshape1;
      TF_RETURN_IF_ERROR(c->Subshape(handle_shape_and_type[0].shape, 0,
                                     batch_dims, &params_subshape1));

      ShapeHandle params_subshape2;
      TF_RETURN_IF_ERROR(c->Subshape(handle_shape_and_type[0].shape,
                                     batch_dims + 1, &params_subshape2));

      ShapeHandle indices_subshape;
      TF_RETURN_IF_ERROR(
          c->Subshape(indices_shape, batch_dims, &indices_subshape));

      // The out shape is params_shape[:batch_dims] +
      // indices_shape[batch_dims:] + params_shape[batch_dims+1:].
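      // For example (illustrative shapes): with batch_dims = 1, a variable
      // of shape [B, P, D] gathered with indices of shape [B, I] produces an
      // output of shape [B] + [I] + [D] = [B, I, D].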
      ShapeHandle out;
      TF_RETURN_IF_ERROR(
          c->Concatenate(params_subshape1, indices_subshape, &out));
      TF_RETURN_IF_ERROR(c->Concatenate(out, params_subshape2, &out));

      c->set_output(0, out);
      if (handle_shape_and_type[0].dtype == DT_VARIANT &&
          handle_shape_and_type.size() > 1) {
        std::vector<ShapeAndType> variant_shape_and_type;
        std::copy(handle_shape_and_type.begin() + 1,
                  handle_shape_and_type.end(),
                  std::back_inserter(variant_shape_and_type));
        c->set_output_handle_shapes_and_types(0, variant_shape_and_type);
      }
      return OkStatus();
    });

REGISTER_OP("ResourceGatherNd")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Output("output: dtype")
    .Attr("dtype: type")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(shape_inference::GatherNdShape);

namespace {

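// Shape function shared by the ResourceScatter* ops: updates must either be
// a scalar (which broadcasts) or have shape indices.shape + var.shape[1:].
// For example (illustrative shapes): a variable of shape [V, D] scattered at
// indices of shape [N] requires updates of shape [N, D].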
Status ResourceScatterUpdateShape(InferenceContext* c) {
  std::vector<ShapeAndType> handle_shape_and_type;
  TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle(
      c, &handle_shape_and_type));
  ShapeHandle var_shape = handle_shape_and_type[0].shape;
  ShapeHandle indices_shape = c->input(1);

  ShapeHandle unused_updates_shape;
  ShapeHandle concat;
  ShapeHandle var_subshape;
  TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape));
  TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat));
  TF_RETURN_IF_ERROR(
      InferenceContext::Rank(c->input(2)) == 0
          ? OkStatus()
          : c->Merge(c->input(2), concat, &unused_updates_shape));
  if (handle_shape_and_type[0].dtype == DT_VARIANT &&
      handle_shape_and_type.size() > 1 &&
      c->input_handle_shapes_and_types(2) != nullptr) {
    auto* value_handle_shape_and_type = c->input_handle_shapes_and_types(2);
    if (value_handle_shape_and_type->size() !=
        handle_shape_and_type.size() - 1) {
      return errors::InvalidArgument(
          "Incompatible handle variant shape_and_type size and input "
          "shape_and_type size: ",
          handle_shape_and_type.size() - 1, " vs. ",
          value_handle_shape_and_type->size());
    }
  }
  return OkStatus();
}

}  // namespace

REGISTER_OP("ResourceScatterAdd")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterSub")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterMul")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterDiv")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterMin")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterMax")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterUpdate")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("MutexV2")
    .Attr("container: string = ''")
    .Attr("shared_name: string = ''")
    .Output("resource: resource")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return OkStatus();
    });

REGISTER_OP("MutexLock")
    .Input("mutex: resource")
    .Output("mutex_lock: variant")
    .SetIsStateful()
    .SetTypeConstructor(full_type::Nullary(TFT_MUTEX_LOCK))
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->Scalar());
      return OkStatus();
    });

REGISTER_OP("ConsumeMutexLock")
    .Input("mutex_lock: variant")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) { return OkStatus(); });

REGISTER_OP("DisableCopyOnRead")
    .Input("resource: resource")
    .SetShapeFn(shape_inference::NoOutputs);

}  // namespace tensorflow