1 | // Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | // ============================================================================ |
15 | |
16 | #include <algorithm> |
17 | |
18 | #include "tensorflow/core/framework/common_shape_fns.h" |
19 | #include "tensorflow/core/framework/function.h" |
20 | #include "tensorflow/core/framework/node_def_util.h" |
21 | #include "tensorflow/core/framework/op.h" |
22 | #include "tensorflow/core/framework/resource_mgr.h" |
23 | #include "tensorflow/core/framework/shape_inference.h" |
24 | #include "tensorflow/core/lib/core/errors.h" |
25 | |
26 | using ::tensorflow::shape_inference::InferenceContext; |
27 | using ::tensorflow::shape_inference::ShapeAndType; |
28 | using ::tensorflow::shape_inference::ShapeHandle; |
29 | |
30 | namespace tensorflow { |
31 | |
32 | namespace { |
33 | |
34 | Status ReadVariableShapeFn(InferenceContext* c) { |
35 | // The user can add a "_shape" atribute to ReadVariableOp nodes. It is |
36 | // useful for inferring shapes in a function, when no shape information |
37 | // is passed about input resources. The user can annotate the graph using |
38 | // the variable capture list of the function. |
39 | // If the "_shape" attribute is found, it is used to set the output shape. |
40 | PartialTensorShape p; |
41 | Status annotation_found_status = c->GetAttr("_shape" , &p); |
42 | if (annotation_found_status.ok()) { |
43 | ShapeHandle s; |
44 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); |
45 | c->set_output(0, s); |
46 | } else { |
47 | std::vector<ShapeAndType> shape_and_type; |
48 | TF_RETURN_IF_ERROR( |
49 | shape_inference::ValidateVariableResourceHandle(c, &shape_and_type)); |
50 | c->set_output(0, shape_and_type[0].shape); |
51 | if (shape_and_type[0].dtype == DT_VARIANT && shape_and_type.size() > 1) { |
52 | std::vector<ShapeAndType> variant_shape_and_type; |
53 | std::copy(shape_and_type.begin() + 1, shape_and_type.end(), |
54 | std::back_inserter(variant_shape_and_type)); |
55 | c->set_output_handle_shapes_and_types(0, variant_shape_and_type); |
56 | } |
57 | } |
58 | return OkStatus(); |
59 | } |
60 | |
61 | Status ReadVariablesShapeFn(InferenceContext* c) { |
62 | int n; |
63 | TF_RETURN_IF_ERROR(c->GetAttr("N" , &n)); |
64 | DataTypeVector value_dtypes; |
65 | TF_RETURN_IF_ERROR(c->GetAttr("dtypes" , &value_dtypes)); |
66 | if (n != value_dtypes.size()) { |
67 | return errors::InvalidArgument( |
68 | "Mismatched number of arguments to ReadVariablesOp" ); |
69 | } |
70 | for (int i = 0; i < n; ++i) { |
71 | ShapeAndType shape_and_type; |
72 | auto* handle_data = c->input_handle_shapes_and_types(i); |
73 | if (handle_data == nullptr || handle_data->empty()) { |
74 | shape_and_type.shape = c->UnknownShape(); |
75 | shape_and_type.dtype = DT_INVALID; |
76 | } else { |
77 | shape_and_type = (*handle_data)[0]; |
78 | if (shape_and_type.dtype != value_dtypes[i]) { |
79 | return errors::InvalidArgument( |
80 | "Trying to read variable with wrong dtype. " |
81 | "Expected " , |
82 | DataTypeString(shape_and_type.dtype), " got " , |
83 | DataTypeString(value_dtypes[i])); |
84 | } |
85 | } |
86 | c->set_output(i, shape_and_type.shape); |
87 | } |
88 | return OkStatus(); |
89 | } |
90 | |
91 | } // namespace |
92 | |
93 | REGISTER_OP("VarHandleOp" ) |
94 | .Attr("container: string = ''" ) |
95 | .Attr("shared_name: string = ''" ) |
96 | .Attr("dtype: type" ) |
97 | .Attr("shape: shape" ) |
98 | .Attr("allowed_devices: list(string) = []" ) |
99 | .Output("resource: resource" ) |
100 | .SetIsStateful() |
101 | .SetShapeFn([](InferenceContext* c) { |
102 | c->set_output(0, c->Scalar()); |
103 | DataType t; |
104 | TF_RETURN_IF_ERROR(c->GetAttr("dtype" , &t)); |
105 | PartialTensorShape p; |
106 | TF_RETURN_IF_ERROR(c->GetAttr("shape" , &p)); |
107 | ShapeHandle s; |
108 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s)); |
109 | c->set_output_handle_shapes_and_types(0, |
110 | std::vector<ShapeAndType>{{s, t}}); |
111 | |
112 | return OkStatus(); |
113 | }); |
114 | |
115 | REGISTER_OP("_VarHandlesOp" ) |
116 | .Attr("containers: list(string)" ) |
117 | .Attr("shared_names: list(string)" ) |
118 | .Attr("N: int >= 0" ) |
119 | .Attr("dtypes: list(type)" ) |
120 | .Attr("shapes: list(shape)" ) |
121 | .Output("resources: N * resource" ) |
122 | .SetIsStateful() |
123 | .SetShapeFn([](InferenceContext* c) { |
124 | int n; |
125 | TF_RETURN_IF_ERROR(c->GetAttr("N" , &n)); |
126 | DataTypeVector dtypes; |
127 | TF_RETURN_IF_ERROR(c->GetAttr("dtypes" , &dtypes)); |
128 | std::vector<PartialTensorShape> shapes; |
129 | TF_RETURN_IF_ERROR(c->GetAttr("shapes" , &shapes)); |
130 | if (dtypes.size() != n) { |
131 | return errors::InvalidArgument("Mismatched number of dtypes (n=" , n, |
132 | ", num dtypes=" , dtypes.size(), ")" ); |
133 | } |
134 | if (shapes.size() != n) { |
135 | return errors::InvalidArgument("Mismatched number of shapes (n=" , n, |
136 | ", num shapes=" , shapes.size(), ")" ); |
137 | } |
138 | for (int i = 0; i < n; ++i) { |
139 | c->set_output(i, c->Scalar()); |
140 | ShapeHandle s; |
141 | TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s)); |
142 | c->set_output_handle_shapes_and_types( |
143 | i, std::vector<ShapeAndType>{{s, dtypes[i]}}); |
144 | } |
145 | |
146 | return OkStatus(); |
147 | }); |
148 | |
// ReadVariableOp: one resource input and one output of type `dtype`
// (presumably the variable's current value -- confirm against the kernel).
// The output shape comes from ReadVariableShapeFn: a "_shape" attr annotation
// if present, otherwise the resource's handle data. Its gradient function,
// ReadGrad, is registered separately in this file.
REGISTER_OP("ReadVariableOp")
    .Input("resource: resource")
    .Output("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn(ReadVariableShapeFn);
154 | |
// _ReadVariablesOp: list form of ReadVariableOp with N resource inputs and
// one output per entry of `dtypes`. ReadVariablesShapeFn requires
// len(dtypes) == N and checks each resource's handle-data dtype against the
// matching `dtypes` entry.
REGISTER_OP("_ReadVariablesOp")
    .Attr("N: int >= 0")
    .Input("resources: N * resource")
    .Output("values: dtypes")
    .Attr("dtypes: list(type)")
    .SetShapeFn(ReadVariablesShapeFn);
161 | |
162 | Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) { |
163 | // clang-format off |
164 | *g = FunctionDefHelper::Define( |
165 | // Arg defs |
166 | {"x: resource" , "dy: float" }, |
167 | // Ret val defs |
168 | {"dy: float" }, |
169 | // Attr defs |
170 | {}, |
171 | // Nodes |
172 | {}); |
173 | // clang-format on |
174 | return OkStatus(); |
175 | } |
176 | REGISTER_OP_GRADIENT("ReadVariableOp" , ReadGrad); |
177 | |
// DestroyResourceOp: stateful op taking one resource input and producing no
// outputs. The "ignore_lookup_error" attr defaults to true; presumably it
// controls whether a missing resource is reported as an error -- confirm
// against the kernel.
REGISTER_OP("DestroyResourceOp")
    .Input("resource: resource")
    .Attr("ignore_lookup_error: bool = true")
    .SetIsStateful()
    .SetShapeFn(shape_inference::NoOutputs);
183 | |
184 | Status CreateAssignShapeFn(InferenceContext* c) { |
185 | std::vector<ShapeAndType> handle_shape_and_type; |
186 | TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle( |
187 | c, &handle_shape_and_type)); |
188 | |
189 | ShapeHandle value_shape = c->input(1); |
190 | ShapeHandle unused; |
191 | TF_RETURN_IF_ERROR( |
192 | c->Merge(handle_shape_and_type[0].shape, value_shape, &unused)); |
193 | |
194 | if (handle_shape_and_type[0].dtype == DT_VARIANT && |
195 | handle_shape_and_type.size() > 1 && |
196 | c->input_handle_shapes_and_types(1) != nullptr) { |
197 | auto* value_handle_shape_and_type = c->input_handle_shapes_and_types(1); |
198 | if (value_handle_shape_and_type->size() != |
199 | handle_shape_and_type.size() - 1) { |
200 | return errors::InvalidArgument( |
201 | "Incompatible handle variant shape_and_type size and input " |
202 | "shape_and_type size: " , |
203 | handle_shape_and_type.size() - 1, " vs. " , |
204 | value_handle_shape_and_type->size()); |
205 | } |
206 | } |
207 | return OkStatus(); |
208 | } |
209 | |
// The three assign ops each take a resource and a value of type `dtype`, and
// share CreateAssignShapeFn, which checks that the value is compatible with
// the variable's shape. Only AssignVariableOp declares a "validate_shape"
// attr (default false).
REGISTER_OP("AssignVariableOp")
    .Input("resource: resource")
    .Input("value: dtype")
    .Attr("dtype: type")
    .Attr("validate_shape: bool = false")
    .SetShapeFn(CreateAssignShapeFn);

REGISTER_OP("AssignAddVariableOp")
    .Input("resource: resource")
    .Input("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn(CreateAssignShapeFn);

REGISTER_OP("AssignSubVariableOp")
    .Input("resource: resource")
    .Input("value: dtype")
    .Attr("dtype: type")
    .SetShapeFn(CreateAssignShapeFn);

// VarIsInitializedOp: one resource input and a scalar bool output.
REGISTER_OP("VarIsInitializedOp")
    .Input("resource: resource")
    .Output("is_initialized: bool")
    .SetShapeFn(tensorflow::shape_inference::ScalarShape);
233 | |
234 | Status VariableShapeShapeFn(InferenceContext* c) { |
235 | auto* handle_data = c->input_handle_shapes_and_types(0); |
236 | if (handle_data == nullptr || handle_data->empty()) { |
237 | c->set_output(0, c->Vector(c->UnknownDim())); |
238 | return OkStatus(); |
239 | } |
240 | ShapeHandle var_shape = (*handle_data)[0].shape; |
241 | int64_t rank = c->RankKnown(var_shape) ? c->Rank(var_shape) |
242 | : InferenceContext::kUnknownDim; |
243 | c->set_output(0, c->Vector(rank)); |
244 | return OkStatus(); |
245 | } |
246 | |
// VariableShape: one resource input. The output is an int32 (default) or int64
// vector whose length VariableShapeShapeFn infers from the variable's rank.
REGISTER_OP("VariableShape")
    .Input("input: resource")
    .Output("output: out_type")
    .Attr("out_type: {int32, int64} = DT_INT32")
    .SetShapeFn(VariableShapeShapeFn);
252 | |
253 | REGISTER_OP("ResourceGather" ) |
254 | .Input("resource: resource" ) |
255 | .Input("indices: Tindices" ) |
256 | .Attr("batch_dims: int = 0" ) |
257 | .Attr("validate_indices: bool = true" ) |
258 | .Output("output: dtype" ) |
259 | .Attr("dtype: type" ) |
260 | .Attr("Tindices: {int32,int64}" ) |
261 | .SetShapeFn([](InferenceContext* c) { |
262 | std::vector<ShapeAndType> handle_shape_and_type; |
263 | TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle( |
264 | c, &handle_shape_and_type)); |
265 | |
266 | ShapeHandle indices_shape = c->input(1); |
267 | |
268 | ShapeHandle unused; |
269 | int32_t batch_dims; |
270 | TF_RETURN_IF_ERROR(c->GetAttr("batch_dims" , &batch_dims)); |
271 | if (batch_dims < 0) |
272 | return errors::InvalidArgument("batch_dims is negative (" , batch_dims, |
273 | ")" ); |
274 | |
275 | TF_RETURN_IF_ERROR(c->WithRankAtLeast(handle_shape_and_type[0].shape, |
276 | batch_dims + 1, &unused)); |
277 | |
278 | TF_RETURN_IF_ERROR( |
279 | c->WithRankAtLeast(indices_shape, batch_dims, &unused)); |
280 | |
281 | ShapeHandle params_subshape1; |
282 | TF_RETURN_IF_ERROR(c->Subshape(handle_shape_and_type[0].shape, 0, |
283 | batch_dims, ¶ms_subshape1)); |
284 | |
285 | ShapeHandle params_subshape2; |
286 | TF_RETURN_IF_ERROR(c->Subshape(handle_shape_and_type[0].shape, |
287 | batch_dims + 1, ¶ms_subshape2)); |
288 | |
289 | ShapeHandle indices_subshape; |
290 | TF_RETURN_IF_ERROR( |
291 | c->Subshape(indices_shape, batch_dims, &indices_subshape)); |
292 | |
293 | // The out shape is params_shape[:batch_dims] + |
294 | // indices_shape[batch_dims:] + params_shape[batch_dims+1:]. |
295 | ShapeHandle out; |
296 | TF_RETURN_IF_ERROR( |
297 | c->Concatenate(params_subshape1, indices_subshape, &out)); |
298 | TF_RETURN_IF_ERROR(c->Concatenate(out, params_subshape2, &out)); |
299 | |
300 | c->set_output(0, out); |
301 | if (handle_shape_and_type[0].dtype == DT_VARIANT && |
302 | !handle_shape_and_type.empty()) { |
303 | std::vector<ShapeAndType> variant_shape_and_type; |
304 | std::copy(handle_shape_and_type.begin() + 1, |
305 | handle_shape_and_type.end(), |
306 | std::back_inserter(variant_shape_and_type)); |
307 | c->set_output_handle_shapes_and_types(0, variant_shape_and_type); |
308 | } |
309 | return OkStatus(); |
310 | }); |
311 | |
// ResourceGatherNd: resource + indices inputs, with an output of type `dtype`.
// Unlike ResourceGather, it declares no batch_dims attr. Shape inference is
// delegated to shape_inference::GatherNdShape (defined outside this file).
REGISTER_OP("ResourceGatherNd")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Output("output: dtype")
    .Attr("dtype: type")
    .Attr("Tindices: {int32,int64}")
    .SetShapeFn(shape_inference::GatherNdShape);
319 | |
320 | namespace { |
321 | |
322 | Status ResourceScatterUpdateShape(InferenceContext* c) { |
323 | std::vector<ShapeAndType> handle_shape_and_type; |
324 | TF_RETURN_IF_ERROR(shape_inference::ValidateVariableResourceHandle( |
325 | c, &handle_shape_and_type)); |
326 | ShapeHandle var_shape = handle_shape_and_type[0].shape; |
327 | ShapeHandle indices_shape = c->input(1); |
328 | |
329 | ShapeHandle unused_updates_shape; |
330 | ShapeHandle concat; |
331 | ShapeHandle var_subshape; |
332 | TF_RETURN_IF_ERROR(c->Subshape(var_shape, 1, &var_subshape)); |
333 | TF_RETURN_IF_ERROR(c->Concatenate(indices_shape, var_subshape, &concat)); |
334 | TF_RETURN_IF_ERROR( |
335 | InferenceContext::Rank(c->input(2)) == 0 |
336 | ? OkStatus() |
337 | : c->Merge(c->input(2), concat, &unused_updates_shape)); |
338 | if (handle_shape_and_type[0].dtype == DT_VARIANT && |
339 | handle_shape_and_type.size() > 1 && |
340 | c->input_handle_shapes_and_types(2) != nullptr) { |
341 | auto* value_handle_shape_and_type = c->input_handle_shapes_and_types(2); |
342 | if (value_handle_shape_and_type->size() != |
343 | handle_shape_and_type.size() - 1) { |
344 | return errors::InvalidArgument( |
345 | "Incompatible handle variant shape_and_type size and input " |
346 | "shape_and_type size: " , |
347 | handle_shape_and_type.size() - 1, " vs. " , |
348 | value_handle_shape_and_type->size()); |
349 | } |
350 | } |
351 | return OkStatus(); |
352 | } |
353 | |
354 | } // namespace |
355 | |
// ResourceScatter* ops: each takes (resource, indices, updates) and shares
// ResourceScatterUpdateShape for shape checking. The arithmetic variants
// (Add/Sub/Mul/Div/Min/Max) restrict `dtype` to numbertype, while
// ResourceScatterUpdate accepts any type. None of them declares an output.
REGISTER_OP("ResourceScatterAdd")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterSub")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterMul")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterDiv")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterMin")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

REGISTER_OP("ResourceScatterMax")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: numbertype")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);

// Unlike the arithmetic variants above, `dtype` here is unrestricted.
REGISTER_OP("ResourceScatterUpdate")
    .Input("resource: resource")
    .Input("indices: Tindices")
    .Input("updates: dtype")
    .Attr("dtype: type")
    .Attr("Tindices: {int32, int64}")
    .SetShapeFn(ResourceScatterUpdateShape);
411 | |
412 | REGISTER_OP("MutexV2" ) |
413 | .Attr("container: string = ''" ) |
414 | .Attr("shared_name: string = ''" ) |
415 | .Output("resource: resource" ) |
416 | .SetIsStateful() |
417 | .SetShapeFn([](InferenceContext* c) { |
418 | c->set_output(0, c->Scalar()); |
419 | return OkStatus(); |
420 | }); |
421 | |
422 | REGISTER_OP("MutexLock" ) |
423 | .Input("mutex: resource" ) |
424 | .Output("mutex_lock: variant" ) |
425 | .SetIsStateful() |
426 | .SetTypeConstructor(full_type::Nullary(TFT_MUTEX_LOCK)) |
427 | .SetShapeFn([](InferenceContext* c) { |
428 | c->set_output(0, c->Scalar()); |
429 | return OkStatus(); |
430 | }); |
431 | |
// ConsumeMutexLock: stateful op taking the variant produced by MutexLock. It
// declares no outputs, so its shape function has nothing to set.
REGISTER_OP("ConsumeMutexLock")
    .Input("mutex_lock: variant")
    .SetIsStateful()
    .SetShapeFn([](InferenceContext* c) { return OkStatus(); });

// DisableCopyOnRead: takes one resource input and declares no outputs. The
// name suggests it switches the variable's copy-on-read mode -- confirm
// against the kernel.
REGISTER_OP("DisableCopyOnRead")
    .Input("resource: resource")
    .SetShapeFn(shape_inference::NoOutputs);
440 | |
441 | } // namespace tensorflow |
442 | |