1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/common_shape_fns.h"
17#include "tensorflow/core/framework/op.h"
18#include "tensorflow/core/framework/shape_inference.h"
19#include "tensorflow/core/util/saved_tensor_slice_util.h"
20
21namespace tensorflow {
22
23using shape_inference::DimensionHandle;
24using shape_inference::InferenceContext;
25using shape_inference::ShapeHandle;
26
27namespace {
28
29Status ScalarInputsAndOutputs(InferenceContext* c) {
30 ShapeHandle unused;
31 for (int i = 0; i < c->num_inputs(); ++i) {
32 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 0, &unused));
33 }
34 for (int i = 0; i < c->num_outputs(); ++i) {
35 c->set_output(i, c->Scalar());
36 }
37 return OkStatus();
38}
39
40Status TwoElementVectorAndScalarOutputs(InferenceContext* c) {
41 ShapeHandle handle;
42 DimensionHandle unused_handle;
43 for (int i = 0; i < c->num_inputs(); ++i) {
44 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &handle));
45 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(handle, 0), 2, &unused_handle));
46 }
47 for (int i = 0; i < c->num_outputs(); ++i) {
48 c->set_output(i, c->Scalar());
49 }
50 return OkStatus();
51}
52
53Status TwoElementOutput(InferenceContext* c) {
54 c->set_output(0, c->Vector(2));
55 return OkStatus();
56}
57
58} // namespace
59
60REGISTER_OP("SaveV2")
61 .Input("prefix: string")
62 .Input("tensor_names: string")
63 .Input("shape_and_slices: string")
64 .Input("tensors: dtypes")
65 .Attr("dtypes: list(type)")
66 .SetIsStateful()
67 .SetShapeFn([](InferenceContext* c) {
68 ShapeHandle unused;
69 ShapeHandle s;
70 DimensionHandle unused_dim;
71
72 // Validate prefix.
73 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
74
75 // Validate tensor_names and shapes_and_slices.
76 for (int i = 1; i <= 2; ++i) {
77 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
78 TF_RETURN_IF_ERROR(
79 c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
80 }
81 // TODO(mrry): Attempt to parse the shapes_and_slices values and use
82 // them to constrain the shape of the remaining inputs.
83 return OkStatus();
84 });
85
86REGISTER_OP("RestoreV2")
87 .Input("prefix: string")
88 .Input("tensor_names: string")
89 .Input("shape_and_slices: string")
90 .Output("tensors: dtypes")
91 .Attr("dtypes: list(type)")
92 .SetIsStateful()
93 .SetShapeFn([](InferenceContext* c) {
94 ShapeHandle shape0, shape1, shape2;
95 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &shape0));
96 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &shape1));
97 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &shape2));
98 TF_RETURN_IF_ERROR(c->Merge(shape1, shape2, &shape0));
99
100 // Attempt to infer output shapes from its shape_and_slice input.
101 const Tensor* shape_and_slices_tensor = c->input_tensor(2);
102 if (shape_and_slices_tensor) {
103 if (shape_and_slices_tensor->dtype() != DT_STRING) {
104 return errors::InvalidArgument(
105 "Expected an input tensor of type string.");
106 }
107
108 const auto& shape_and_slices_flat =
109 shape_and_slices_tensor->flat<tstring>();
110 if (shape_and_slices_flat.size() != c->num_outputs()) {
111 return errors::InvalidArgument(
112 "The number of shape_and_slice doesn't match tensor outputs.");
113 }
114 for (int i = 0; i < shape_and_slices_flat.size(); ++i) {
115 const string& shape_and_slice = shape_and_slices_flat(i);
116 if (shape_and_slice.empty()) {
117 c->set_output(i, c->UnknownShape());
118 continue;
119 }
120 TensorShape parsed_full_shape;
121 TensorSlice parsed_slice;
122 TensorShape parsed_slice_shape;
123 TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
124 shape_and_slice, &parsed_full_shape, &parsed_slice,
125 &parsed_slice_shape));
126 ShapeHandle shape_handle;
127 TF_RETURN_IF_ERROR(
128 c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
129 c->set_output(i, shape_handle);
130 }
131 return OkStatus();
132 } else {
133 return UnknownShape(c);
134 }
135 });
136
137REGISTER_OP("MergeV2Checkpoints")
138 .Input("checkpoint_prefixes: string")
139 .Input("destination_prefix: string")
140 .Attr("delete_old_dirs: bool = true")
141 .Attr("allow_missing_files: bool = false")
142 .SetIsStateful()
143 .SetShapeFn([](InferenceContext* c) {
144 ShapeHandle unused;
145 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
146 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
147 return OkStatus();
148 });
149
150REGISTER_OP("Save")
151 .Input("filename: string")
152 .Input("tensor_names: string")
153 .Input("data: T")
154 .Attr("T: list(type)")
155 .SetIsStateful()
156 .SetShapeFn([](InferenceContext* c) {
157 ShapeHandle unused;
158 ShapeHandle s;
159 DimensionHandle unused_dim;
160
161 // Validate filename.
162 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
163
164 // Validate tensor_names.
165 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &s));
166 TF_RETURN_IF_ERROR(
167 c->WithValue(c->Dim(s, 0), c->num_inputs() - 2, &unused_dim));
168
169 return OkStatus();
170 });
171
172REGISTER_OP("SaveSlices")
173 .Input("filename: string")
174 .Input("tensor_names: string")
175 .Input("shapes_and_slices: string")
176 .Input("data: T")
177 .Attr("T: list(type)")
178 .SetIsStateful()
179 .SetShapeFn([](InferenceContext* c) {
180 ShapeHandle unused;
181 ShapeHandle s;
182 DimensionHandle unused_dim;
183
184 // Validate filename.
185 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
186
187 // Validate tensor_names and unused_shapes_and_slices.
188 for (int i = 1; i <= 2; ++i) {
189 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &s));
190 TF_RETURN_IF_ERROR(
191 c->WithValue(c->Dim(s, 0), c->num_inputs() - 3, &unused_dim));
192 }
193 // TODO(mrry): Attempt to parse the shapes_and_slices values and use
194 // them to constrain the shape of the remaining inputs.
195 return OkStatus();
196 });
197
198REGISTER_OP("Restore")
199 .Input("file_pattern: string")
200 .Input("tensor_name: string")
201 .Output("tensor: dt")
202 .Attr("dt: type")
203 .Attr("preferred_shard: int = -1")
204 .SetIsStateful()
205 .SetShapeFn([](InferenceContext* c) {
206 ShapeHandle unused;
207 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
208 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
209 c->set_output(0, c->UnknownShape());
210 return OkStatus();
211 });
212
213REGISTER_OP("RestoreSlice")
214 .Input("file_pattern: string")
215 .Input("tensor_name: string")
216 .Input("shape_and_slice: string")
217 .Output("tensor: dt")
218 .Attr("dt: type")
219 .Attr("preferred_shard: int = -1")
220 .SetIsStateful()
221 .SetShapeFn([](InferenceContext* c) {
222 ShapeHandle unused;
223 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
224 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
225 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
226
227 // Attempt to infer output shapes from its shape_and_slice input.
228 const Tensor* shape_and_slices_tensor = c->input_tensor(2);
229 if (shape_and_slices_tensor) {
230 const auto& shape_and_slice =
231 shape_and_slices_tensor->flat<tstring>()(0);
232 if (shape_and_slice.empty()) {
233 c->set_output(0, c->UnknownShape());
234 } else {
235 TensorShape parsed_full_shape;
236 TensorSlice parsed_slice;
237 TensorShape parsed_slice_shape;
238 TF_RETURN_IF_ERROR(checkpoint::ParseShapeAndSlice(
239 shape_and_slice, &parsed_full_shape, &parsed_slice,
240 &parsed_slice_shape));
241 ShapeHandle shape_handle;
242 TF_RETURN_IF_ERROR(
243 c->MakeShapeFromTensorShape(parsed_slice_shape, &shape_handle));
244 c->set_output(0, shape_handle);
245 }
246 } else {
247 c->set_output(0, c->UnknownShape());
248 }
249 return OkStatus();
250 });
251
252REGISTER_OP("ShardedFilename")
253 .Input("basename: string")
254 .Input("shard: int32")
255 .Input("num_shards: int32")
256 .Output("filename: string")
257 .SetShapeFn(ScalarInputsAndOutputs);
258
259REGISTER_OP("ShardedFilespec")
260 .Input("basename: string")
261 .Input("num_shards: int32")
262 .Output("filename: string")
263 .SetShapeFn(ScalarInputsAndOutputs);
264
265// Reader source ops ----------------------------------------------------------
266
267REGISTER_OP("WholeFileReader")
268 .Output("reader_handle: Ref(string)")
269 .Attr("container: string = ''")
270 .Attr("shared_name: string = ''")
271 .SetIsStateful()
272 .SetShapeFn(TwoElementOutput);
273
274REGISTER_OP("WholeFileReaderV2")
275 .Output("reader_handle: resource")
276 .Attr("container: string = ''")
277 .Attr("shared_name: string = ''")
278 .SetIsStateful()
279 .SetShapeFn(shape_inference::ScalarShape);
280
281REGISTER_OP("TextLineReader")
282 .Output("reader_handle: Ref(string)")
283 .Attr("skip_header_lines: int = 0")
284 .Attr("container: string = ''")
285 .Attr("shared_name: string = ''")
286 .SetIsStateful()
287 .SetShapeFn(TwoElementOutput)
288 .Deprecated(26, "Use TextLineReaderV2");
289
290REGISTER_OP("TextLineReaderV2")
291 .Output("reader_handle: resource")
292 .Attr("skip_header_lines: int = 0")
293 .Attr("container: string = ''")
294 .Attr("shared_name: string = ''")
295 .SetIsStateful()
296 .SetShapeFn(shape_inference::ScalarShape);
297
298REGISTER_OP("FixedLengthRecordReader")
299 .Output("reader_handle: Ref(string)")
300 .Attr("header_bytes: int = 0")
301 .Attr("record_bytes: int")
302 .Attr("footer_bytes: int = 0")
303 .Attr("hop_bytes: int = 0")
304 .Attr("container: string = ''")
305 .Attr("shared_name: string = ''")
306 .SetIsStateful()
307 .SetShapeFn(TwoElementOutput)
308 .Deprecated(26, "Use FixedLengthRecordReaderV2");
309
310REGISTER_OP("FixedLengthRecordReaderV2")
311 .Output("reader_handle: resource")
312 .Attr("header_bytes: int = 0")
313 .Attr("record_bytes: int")
314 .Attr("footer_bytes: int = 0")
315 .Attr("hop_bytes: int = 0")
316 .Attr("container: string = ''")
317 .Attr("shared_name: string = ''")
318 .Attr("encoding: string = ''")
319 .SetIsStateful()
320 .SetShapeFn(shape_inference::ScalarShape);
321
322REGISTER_OP("TFRecordReader")
323 .Output("reader_handle: Ref(string)")
324 .Attr("container: string = ''")
325 .Attr("shared_name: string = ''")
326 .Attr("compression_type: string = ''")
327 .SetIsStateful()
328 .SetShapeFn(TwoElementOutput)
329 .Deprecated(26, "Use TFRecordReaderV2");
330
331REGISTER_OP("TFRecordReaderV2")
332 .Output("reader_handle: resource")
333 .Attr("container: string = ''")
334 .Attr("shared_name: string = ''")
335 .Attr("compression_type: string = ''")
336 .SetIsStateful()
337 .SetShapeFn(shape_inference::ScalarShape);
338
339REGISTER_OP("LMDBReader")
340 .Output("reader_handle: Ref(string)")
341 .Attr("container: string = ''")
342 .Attr("shared_name: string = ''")
343 .SetIsStateful()
344 .SetShapeFn(TwoElementOutput);
345
346REGISTER_OP("IdentityReader")
347 .Output("reader_handle: Ref(string)")
348 .Attr("container: string = ''")
349 .Attr("shared_name: string = ''")
350 .SetIsStateful()
351 .SetShapeFn(TwoElementOutput)
352 .Deprecated(26, "Use IdentityReaderV2");
353
354REGISTER_OP("IdentityReaderV2")
355 .Output("reader_handle: resource")
356 .Attr("container: string = ''")
357 .Attr("shared_name: string = ''")
358 .SetIsStateful()
359 .SetShapeFn(shape_inference::ScalarShape);
360
361// Ops that operate on Readers ------------------------------------------------
362
363REGISTER_OP("ReaderRead")
364 .Input("reader_handle: Ref(string)")
365 .Input("queue_handle: Ref(string)")
366 .Output("key: string")
367 .Output("value: string")
368 .SetShapeFn(TwoElementVectorAndScalarOutputs);
369
370REGISTER_OP("ReaderReadV2")
371 .Input("reader_handle: resource")
372 .Input("queue_handle: resource")
373 .Output("key: string")
374 .Output("value: string")
375 .SetShapeFn(ScalarInputsAndOutputs);
376
377REGISTER_OP("ReaderReadUpTo")
378 .Input("reader_handle: Ref(string)")
379 .Input("queue_handle: Ref(string)")
380 .Input("num_records: int64")
381 .Output("keys: string")
382 .Output("values: string")
383 .SetShapeFn([](InferenceContext* c) {
384 ShapeHandle unused;
385 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
386 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
387 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
388 ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
389 c->set_output(0, out);
390 c->set_output(1, out);
391 return OkStatus();
392 });
393
394REGISTER_OP("ReaderReadUpToV2")
395 .Input("reader_handle: resource")
396 .Input("queue_handle: resource")
397 .Input("num_records: int64")
398 .Output("keys: string")
399 .Output("values: string")
400 .SetShapeFn([](InferenceContext* c) {
401 ShapeHandle unused;
402 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
403 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
404 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
405 ShapeHandle out = c->Vector(InferenceContext::kUnknownDim);
406 c->set_output(0, out);
407 c->set_output(1, out);
408 return OkStatus();
409 });
410
411REGISTER_OP("ReaderNumRecordsProduced")
412 .Input("reader_handle: Ref(string)")
413 .Output("records_produced: int64")
414 .SetShapeFn(TwoElementVectorAndScalarOutputs);
415
416REGISTER_OP("ReaderNumRecordsProducedV2")
417 .Input("reader_handle: resource")
418 .Output("records_produced: int64")
419 .SetShapeFn(ScalarInputsAndOutputs);
420
421REGISTER_OP("ReaderNumWorkUnitsCompleted")
422 .Input("reader_handle: Ref(string)")
423 .Output("units_completed: int64")
424 .SetShapeFn(TwoElementVectorAndScalarOutputs);
425
426REGISTER_OP("ReaderNumWorkUnitsCompletedV2")
427 .Input("reader_handle: resource")
428 .Output("units_completed: int64")
429 .SetShapeFn(ScalarInputsAndOutputs);
430
431REGISTER_OP("ReaderSerializeState")
432 .Input("reader_handle: Ref(string)")
433 .Output("state: string")
434 .SetShapeFn(TwoElementVectorAndScalarOutputs);
435
436REGISTER_OP("ReaderSerializeStateV2")
437 .Input("reader_handle: resource")
438 .Output("state: string")
439 .SetShapeFn(ScalarInputsAndOutputs);
440
441REGISTER_OP("ReaderRestoreState")
442 .Input("reader_handle: Ref(string)")
443 .Input("state: string")
444 .SetShapeFn([](InferenceContext* c) {
445 ShapeHandle unused;
446 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
447 DimensionHandle unused_handle;
448 TF_RETURN_IF_ERROR(
449 c->WithValue(c->Dim(c->input(0), 0), 2, &unused_handle));
450
451 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
452 return OkStatus();
453 });
454
455REGISTER_OP("ReaderRestoreStateV2")
456 .Input("reader_handle: resource")
457 .Input("state: string")
458 .SetShapeFn([](InferenceContext* c) {
459 ShapeHandle unused;
460 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
461 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
462 return OkStatus();
463 });
464
465REGISTER_OP("ReaderReset")
466 .Input("reader_handle: Ref(string)")
467 .SetShapeFn(TwoElementVectorAndScalarOutputs);
468
469REGISTER_OP("ReaderResetV2")
470 .Input("reader_handle: resource")
471 .SetShapeFn(ScalarInputsAndOutputs);
472
473// Other input Ops ----------------------------------------------------------
474
475REGISTER_OP("ReadFile")
476 .Input("filename: string")
477 .Output("contents: string")
478 .SetShapeFn(ScalarInputsAndOutputs);
479
480REGISTER_OP("WriteFile")
481 .Input("filename: string")
482 .Input("contents: string")
483 .SetIsStateful()
484 .SetShapeFn([](InferenceContext* c) {
485 ShapeHandle unused;
486 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
487 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
488 return OkStatus();
489 });
490
491REGISTER_OP("MatchingFiles")
492 .Input("pattern: string")
493 .Output("filenames: string")
494 .SetShapeFn([](InferenceContext* c) {
495 ShapeHandle unused;
496 TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
497 c->set_output(0, c->Vector(InferenceContext::kUnknownDim));
498 return OkStatus();
499 });
500
501} // namespace tensorflow
502