1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | #include "tensorflow/core/util/saved_tensor_slice_util.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | using shape_inference::DimensionHandle; |
24 | using shape_inference::InferenceContext; |
25 | using shape_inference::ShapeHandle; |
26 | |
27 | namespace { |
28 | |
29 | Status ScalarInputsAndOutputs(InferenceContext* c) { |
30 | ShapeHandle unused; |
31 | for (int i = 0; i < c->num_inputs(); ++i) { |
32 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused)); |
33 | } |
34 | for (int i = 0; i < c->num_outputs(); ++i) { |
35 | c->set_output(i, c->Scalar()); |
36 | } |
37 | return OkStatus(); |
38 | } |
39 | |
40 | Status TwoElementVectorAndScalarOutputs(InferenceContext* c) { |
41 | ShapeHandle handle; |
42 | DimensionHandle unused_handle; |
43 | for (int i = 0; i < c->num_inputs(); ++i) { |
44 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle)); |
45 | TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle)); |
46 | } |
47 | for (int i = 0; i < c->num_outputs(); ++i) { |
48 | c->set_output(i, c->Scalar()); |
49 | } |
50 | return OkStatus(); |
51 | } |
52 | |
53 | Status TwoElementOutput(InferenceContext* c) { |
54 | c->set_output(0, c->Vector(2)); |
55 | return OkStatus(); |
56 | } |
57 | |
58 | } // namespace |
59 | |
60 | REGISTER_OP("SaveV2" ) |
61 | .Input("prefix: string" ) |
62 | .Input("tensor_names: string" ) |
63 | .Input("shape_and_slices: string" ) |
64 | .Input("tensors: dtypes" ) |
65 | .Attr("dtypes: list(type)" ) |
66 | .SetIsStateful() |
67 | .SetShapeFn([](InferenceContext* c) { |
68 | ShapeHandle unused; |
69 | ShapeHandle s; |
70 | DimensionHandle unused_dim; |
71 | |
72 | // Validate prefix. |
73 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
74 | |
75 | // Validate tensor_names and shapes_and_slices. |
76 | for (int i = 1; i <= 2; ++i) { |
77 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s)); |
78 | TF_RETURN_IF_ERROR( |
79 | c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim)); |
80 | } |
81 | // TODO(mrry): Attempt to parse the shapes_and_slices values and use |
82 | // them to constrain the shape of the remaining inputs. |
83 | return OkStatus(); |
84 | }); |
85 | |
86 | REGISTER_OP("RestoreV2" ) |
87 | .Input("prefix: string" ) |
88 | .Input("tensor_names: string" ) |
89 | .Input("shape_and_slices: string" ) |
90 | .Output("tensors: dtypes" ) |
91 | .Attr("dtypes: list(type)" ) |
92 | .SetIsStateful() |
93 | .SetShapeFn([](InferenceContext* c) { |
94 | ShapeHandle shape0, shape1, shape2; |
95 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &shape0)); |
96 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1)); |
97 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2)); |
98 | TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0)); |
99 | |
100 | // Attempt to infer output shapes from its shape_and_slice input. |
101 | const Tensor* shape_and_slices_tensor = c->input_tensor(2); |
102 | if (shape_and_slices_tensor) { |
103 | if (shape_and_slices_tensor->dtype() != DT_STRING) { |
104 | return errors::InvalidArgument( |
105 | "Expected an input tensor of type string." ); |
106 | } |
107 | |
108 | const auto& shape_and_slices_flat = |
109 | shape_and_slices_tensor->flat<tstring>(); |
110 | if (shape_and_slices_flat.size() != c->num_outputs()) { |
111 | return errors::InvalidArgument( |
112 | "The number of shape_and_slice doesn't match tensor outputs." ); |
113 | } |
114 | for (int i = 0; i < shape_and_slices_flat.size(); ++i) { |
115 | const string& shape_and_slice = shape_and_slices_flat(i); |
116 | if (shape_and_slice.empty()) { |
117 | c->set_output(i, c->UnknownShape()); |
118 | continue; |
119 | } |
120 | TensorShape parsed_full_shape; |
121 | TensorSlice parsed_slice; |
122 | TensorShape parsed_slice_shape; |
123 | TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice( |
124 | shape_and_slice, &parsed_full_shape, &parsed_slice, |
125 | &parsed_slice_shape)); |
126 | ShapeHandle shape_handle; |
127 | TF_RETURN_IF_ERROR( |
128 | c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle)); |
129 | c->set_output(i, shape_handle); |
130 | } |
131 | return OkStatus(); |
132 | } else { |
133 | return UnknownShape(c); |
134 | } |
135 | }); |
136 | |
137 | REGISTER_OP("MergeV2Checkpoints" ) |
138 | .Input("checkpoint_prefixes: string" ) |
139 | .Input("destination_prefix: string" ) |
140 | .Attr("delete_old_dirs: bool = true" ) |
141 | .Attr("allow_missing_files: bool = false" ) |
142 | .SetIsStateful() |
143 | .SetShapeFn([](InferenceContext* c) { |
144 | ShapeHandle unused; |
145 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
146 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
147 | return OkStatus(); |
148 | }); |
149 | |
150 | REGISTER_OP("Save" ) |
151 | .Input("filename: string" ) |
152 | .Input("tensor_names: string" ) |
153 | .Input("data: T" ) |
154 | .Attr("T: list(type)" ) |
155 | .SetIsStateful() |
156 | .SetShapeFn([](InferenceContext* c) { |
157 | ShapeHandle unused; |
158 | ShapeHandle s; |
159 | DimensionHandle unused_dim; |
160 | |
161 | // Validate filename. |
162 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
163 | |
164 | // Validate tensor_names. |
165 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &s)); |
166 | TF_RETURN_IF_ERROR( |
167 | c->WithValue(c->Dim(s, 0), c->num_inputs() - 2, &unused_dim)); |
168 | |
169 | return OkStatus(); |
170 | }); |
171 | |
172 | REGISTER_OP("SaveSlices" ) |
173 | .Input("filename: string" ) |
174 | .Input("tensor_names: string" ) |
175 | .Input("shapes_and_slices: string" ) |
176 | .Input("data: T" ) |
177 | .Attr("T: list(type)" ) |
178 | .SetIsStateful() |
179 | .SetShapeFn([](InferenceContext* c) { |
180 | ShapeHandle unused; |
181 | ShapeHandle s; |
182 | DimensionHandle unused_dim; |
183 | |
184 | // Validate filename. |
185 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
186 | |
187 | // Validate tensor_names and unused_shapes_and_slices. |
188 | for (int i = 1; i <= 2; ++i) { |
189 | TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s)); |
190 | TF_RETURN_IF_ERROR( |
191 | c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim)); |
192 | } |
193 | // TODO(mrry): Attempt to parse the shapes_and_slices values and use |
194 | // them to constrain the shape of the remaining inputs. |
195 | return OkStatus(); |
196 | }); |
197 | |
198 | REGISTER_OP("Restore" ) |
199 | .Input("file_pattern: string" ) |
200 | .Input("tensor_name: string" ) |
201 | .Output("tensor: dt" ) |
202 | .Attr("dt: type" ) |
203 | .Attr("preferred_shard: int = -1" ) |
204 | .SetIsStateful() |
205 | .SetShapeFn([](InferenceContext* c) { |
206 | ShapeHandle unused; |
207 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
208 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
209 | c->set_output(0, c->UnknownShape()); |
210 | return OkStatus(); |
211 | }); |
212 | |
213 | REGISTER_OP("RestoreSlice" ) |
214 | .Input("file_pattern: string" ) |
215 | .Input("tensor_name: string" ) |
216 | .Input("shape_and_slice: string" ) |
217 | .Output("tensor: dt" ) |
218 | .Attr("dt: type" ) |
219 | .Attr("preferred_shard: int = -1" ) |
220 | .SetIsStateful() |
221 | .SetShapeFn([](InferenceContext* c) { |
222 | ShapeHandle unused; |
223 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
224 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
225 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
226 | |
227 | // Attempt to infer output shapes from its shape_and_slice input. |
228 | const Tensor* shape_and_slices_tensor = c->input_tensor(2); |
229 | if (shape_and_slices_tensor) { |
230 | const auto& shape_and_slice = |
231 | shape_and_slices_tensor->flat<tstring>()(0); |
232 | if (shape_and_slice.empty()) { |
233 | c->set_output(0, c->UnknownShape()); |
234 | } else { |
235 | TensorShape parsed_full_shape; |
236 | TensorSlice parsed_slice; |
237 | TensorShape parsed_slice_shape; |
238 | TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice( |
239 | shape_and_slice, &parsed_full_shape, &parsed_slice, |
240 | &parsed_slice_shape)); |
241 | ShapeHandle shape_handle; |
242 | TF_RETURN_IF_ERROR( |
243 | c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle)); |
244 | c->set_output(0, shape_handle); |
245 | } |
246 | } else { |
247 | c->set_output(0, c->UnknownShape()); |
248 | } |
249 | return OkStatus(); |
250 | }); |
251 | |
252 | REGISTER_OP("ShardedFilename" ) |
253 | .Input("basename: string" ) |
254 | .Input("shard: int32" ) |
255 | .Input("num_shards: int32" ) |
256 | .Output("filename: string" ) |
257 | .SetShapeFn(ScalarInputsAndOutputs); |
258 | |
259 | REGISTER_OP("ShardedFilespec" ) |
260 | .Input("basename: string" ) |
261 | .Input("num_shards: int32" ) |
262 | .Output("filename: string" ) |
263 | .SetShapeFn(ScalarInputsAndOutputs); |
264 | |
265 | // Reader source ops ---------------------------------------------------------- |
266 | |
267 | REGISTER_OP("WholeFileReader" ) |
268 | .Output("reader_handle: Ref(string)" ) |
269 | .Attr("container: string = ''" ) |
270 | .Attr("shared_name: string = ''" ) |
271 | .SetIsStateful() |
272 | .SetShapeFn(TwoElementOutput); |
273 | |
274 | REGISTER_OP("WholeFileReaderV2" ) |
275 | .Output("reader_handle: resource" ) |
276 | .Attr("container: string = ''" ) |
277 | .Attr("shared_name: string = ''" ) |
278 | .SetIsStateful() |
279 | .SetShapeFn(shape_inference::ScalarShape); |
280 | |
281 | REGISTER_OP("TextLineReader" ) |
282 | .Output("reader_handle: Ref(string)" ) |
283 | .Attr("skip_header_lines: int = 0" ) |
284 | .Attr("container: string = ''" ) |
285 | .Attr("shared_name: string = ''" ) |
286 | .SetIsStateful() |
287 | .SetShapeFn(TwoElementOutput) |
288 | .Deprecated(26, "Use TextLineReaderV2" ); |
289 | |
290 | REGISTER_OP("TextLineReaderV2" ) |
291 | .Output("reader_handle: resource" ) |
292 | .Attr("skip_header_lines: int = 0" ) |
293 | .Attr("container: string = ''" ) |
294 | .Attr("shared_name: string = ''" ) |
295 | .SetIsStateful() |
296 | .SetShapeFn(shape_inference::ScalarShape); |
297 | |
298 | REGISTER_OP("FixedLengthRecordReader" ) |
299 | .Output("reader_handle: Ref(string)" ) |
300 | .Attr("header_bytes: int = 0" ) |
301 | .Attr("record_bytes: int" ) |
302 | .Attr("footer_bytes: int = 0" ) |
303 | .Attr("hop_bytes: int = 0" ) |
304 | .Attr("container: string = ''" ) |
305 | .Attr("shared_name: string = ''" ) |
306 | .SetIsStateful() |
307 | .SetShapeFn(TwoElementOutput) |
308 | .Deprecated(26, "Use FixedLengthRecordReaderV2" ); |
309 | |
310 | REGISTER_OP("FixedLengthRecordReaderV2" ) |
311 | .Output("reader_handle: resource" ) |
312 | .Attr("header_bytes: int = 0" ) |
313 | .Attr("record_bytes: int" ) |
314 | .Attr("footer_bytes: int = 0" ) |
315 | .Attr("hop_bytes: int = 0" ) |
316 | .Attr("container: string = ''" ) |
317 | .Attr("shared_name: string = ''" ) |
318 | .Attr("encoding: string = ''" ) |
319 | .SetIsStateful() |
320 | .SetShapeFn(shape_inference::ScalarShape); |
321 | |
322 | REGISTER_OP("TFRecordReader" ) |
323 | .Output("reader_handle: Ref(string)" ) |
324 | .Attr("container: string = ''" ) |
325 | .Attr("shared_name: string = ''" ) |
326 | .Attr("compression_type: string = ''" ) |
327 | .SetIsStateful() |
328 | .SetShapeFn(TwoElementOutput) |
329 | .Deprecated(26, "Use TFRecordReaderV2" ); |
330 | |
331 | REGISTER_OP("TFRecordReaderV2" ) |
332 | .Output("reader_handle: resource" ) |
333 | .Attr("container: string = ''" ) |
334 | .Attr("shared_name: string = ''" ) |
335 | .Attr("compression_type: string = ''" ) |
336 | .SetIsStateful() |
337 | .SetShapeFn(shape_inference::ScalarShape); |
338 | |
339 | REGISTER_OP("LMDBReader" ) |
340 | .Output("reader_handle: Ref(string)" ) |
341 | .Attr("container: string = ''" ) |
342 | .Attr("shared_name: string = ''" ) |
343 | .SetIsStateful() |
344 | .SetShapeFn(TwoElementOutput); |
345 | |
346 | REGISTER_OP("IdentityReader" ) |
347 | .Output("reader_handle: Ref(string)" ) |
348 | .Attr("container: string = ''" ) |
349 | .Attr("shared_name: string = ''" ) |
350 | .SetIsStateful() |
351 | .SetShapeFn(TwoElementOutput) |
352 | .Deprecated(26, "Use IdentityReaderV2" ); |
353 | |
354 | REGISTER_OP("IdentityReaderV2" ) |
355 | .Output("reader_handle: resource" ) |
356 | .Attr("container: string = ''" ) |
357 | .Attr("shared_name: string = ''" ) |
358 | .SetIsStateful() |
359 | .SetShapeFn(shape_inference::ScalarShape); |
360 | |
361 | // Ops that operate on Readers ------------------------------------------------ |
362 | |
363 | REGISTER_OP("ReaderRead" ) |
364 | .Input("reader_handle: Ref(string)" ) |
365 | .Input("queue_handle: Ref(string)" ) |
366 | .Output("key: string" ) |
367 | .Output("value: string" ) |
368 | .SetShapeFn(TwoElementVectorAndScalarOutputs); |
369 | |
370 | REGISTER_OP("ReaderReadV2" ) |
371 | .Input("reader_handle: resource" ) |
372 | .Input("queue_handle: resource" ) |
373 | .Output("key: string" ) |
374 | .Output("value: string" ) |
375 | .SetShapeFn(ScalarInputsAndOutputs); |
376 | |
377 | REGISTER_OP("ReaderReadUpTo" ) |
378 | .Input("reader_handle: Ref(string)" ) |
379 | .Input("queue_handle: Ref(string)" ) |
380 | .Input("num_records: int64" ) |
381 | .Output("keys: string" ) |
382 | .Output("values: string" ) |
383 | .SetShapeFn([](InferenceContext* c) { |
384 | ShapeHandle unused; |
385 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
386 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); |
387 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
388 | ShapeHandle out = c->Vector(InferenceContext::kUnknownDim); |
389 | c->set_output(0, out); |
390 | c->set_output(1, out); |
391 | return OkStatus(); |
392 | }); |
393 | |
394 | REGISTER_OP("ReaderReadUpToV2" ) |
395 | .Input("reader_handle: resource" ) |
396 | .Input("queue_handle: resource" ) |
397 | .Input("num_records: int64" ) |
398 | .Output("keys: string" ) |
399 | .Output("values: string" ) |
400 | .SetShapeFn([](InferenceContext* c) { |
401 | ShapeHandle unused; |
402 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
403 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
404 | TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); |
405 | ShapeHandle out = c->Vector(InferenceContext::kUnknownDim); |
406 | c->set_output(0, out); |
407 | c->set_output(1, out); |
408 | return OkStatus(); |
409 | }); |
410 | |
411 | REGISTER_OP("ReaderNumRecordsProduced" ) |
412 | .Input("reader_handle: Ref(string)" ) |
413 | .Output("records_produced: int64" ) |
414 | .SetShapeFn(TwoElementVectorAndScalarOutputs); |
415 | |
416 | REGISTER_OP("ReaderNumRecordsProducedV2" ) |
417 | .Input("reader_handle: resource" ) |
418 | .Output("records_produced: int64" ) |
419 | .SetShapeFn(ScalarInputsAndOutputs); |
420 | |
421 | REGISTER_OP("ReaderNumWorkUnitsCompleted" ) |
422 | .Input("reader_handle: Ref(string)" ) |
423 | .Output("units_completed: int64" ) |
424 | .SetShapeFn(TwoElementVectorAndScalarOutputs); |
425 | |
426 | REGISTER_OP("ReaderNumWorkUnitsCompletedV2" ) |
427 | .Input("reader_handle: resource" ) |
428 | .Output("units_completed: int64" ) |
429 | .SetShapeFn(ScalarInputsAndOutputs); |
430 | |
431 | REGISTER_OP("ReaderSerializeState" ) |
432 | .Input("reader_handle: Ref(string)" ) |
433 | .Output("state: string" ) |
434 | .SetShapeFn(TwoElementVectorAndScalarOutputs); |
435 | |
436 | REGISTER_OP("ReaderSerializeStateV2" ) |
437 | .Input("reader_handle: resource" ) |
438 | .Output("state: string" ) |
439 | .SetShapeFn(ScalarInputsAndOutputs); |
440 | |
441 | REGISTER_OP("ReaderRestoreState" ) |
442 | .Input("reader_handle: Ref(string)" ) |
443 | .Input("state: string" ) |
444 | .SetShapeFn([](InferenceContext* c) { |
445 | ShapeHandle unused; |
446 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); |
447 | DimensionHandle unused_handle; |
448 | TF_RETURN_IF_ERROR( |
449 | c->WithValue(c->Dim(c->input(0), 0), 2, &unused_handle)); |
450 | |
451 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
452 | return OkStatus(); |
453 | }); |
454 | |
455 | REGISTER_OP("ReaderRestoreStateV2" ) |
456 | .Input("reader_handle: resource" ) |
457 | .Input("state: string" ) |
458 | .SetShapeFn([](InferenceContext* c) { |
459 | ShapeHandle unused; |
460 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
461 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
462 | return OkStatus(); |
463 | }); |
464 | |
465 | REGISTER_OP("ReaderReset" ) |
466 | .Input("reader_handle: Ref(string)" ) |
467 | .SetShapeFn(TwoElementVectorAndScalarOutputs); |
468 | |
469 | REGISTER_OP("ReaderResetV2" ) |
470 | .Input("reader_handle: resource" ) |
471 | .SetShapeFn(ScalarInputsAndOutputs); |
472 | |
473 | // Other input Ops ---------------------------------------------------------- |
474 | |
475 | REGISTER_OP("ReadFile" ) |
476 | .Input("filename: string" ) |
477 | .Output("contents: string" ) |
478 | .SetShapeFn(ScalarInputsAndOutputs); |
479 | |
480 | REGISTER_OP("WriteFile" ) |
481 | .Input("filename: string" ) |
482 | .Input("contents: string" ) |
483 | .SetIsStateful() |
484 | .SetShapeFn([](InferenceContext* c) { |
485 | ShapeHandle unused; |
486 | TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); |
487 | TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); |
488 | return OkStatus(); |
489 | }); |
490 | |
491 | REGISTER_OP("MatchingFiles" ) |
492 | .Input("pattern: string" ) |
493 | .Output("filenames: string" ) |
494 | .SetShapeFn([](InferenceContext* c) { |
495 | ShapeHandle unused; |
496 | TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); |
497 | c->set_output(0, c->Vector(InferenceContext::kUnknownDim)); |
498 | return OkStatus(); |
499 | }); |
500 | |
501 | } // namespace tensorflow |
502 | |