1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/io_ops.cc. |
17 | |
18 | #include <memory> |
19 | |
20 | #include "absl/strings/escaping.h" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/reader_base.h" |
23 | #include "tensorflow/core/framework/reader_base.pb.h" |
24 | #include "tensorflow/core/framework/reader_op_kernel.h" |
25 | #include "tensorflow/core/framework/tensor_shape.h" |
26 | #include "tensorflow/core/lib/core/errors.h" |
27 | #include "tensorflow/core/lib/io/buffered_inputstream.h" |
28 | #include "tensorflow/core/lib/io/path.h" |
29 | #include "tensorflow/core/lib/io/random_inputstream.h" |
30 | #include "tensorflow/core/lib/strings/str_util.h" |
31 | #include "tensorflow/core/lib/strings/strcat.h" |
32 | #include "tensorflow/core/platform/env.h" |
33 | #include "tensorflow/core/platform/protobuf.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | template <typename T> |
38 | static Status ReadEntireFile(Env* env, const string& filename, T* contents) { |
39 | std::unique_ptr<RandomAccessFile> file; |
40 | TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file)); |
41 | io::RandomAccessInputStream input_stream(file.get()); |
42 | io::BufferedInputStream in(&input_stream, 1 << 20); |
43 | TF_RETURN_IF_ERROR(in.ReadAll(contents)); |
44 | return OkStatus(); |
45 | } |
46 | |
47 | class WholeFileReader : public ReaderBase { |
48 | public: |
49 | WholeFileReader(Env* env, const string& node_name) |
50 | : ReaderBase(strings::StrCat("WholeFileReader '" , node_name, "'" )), |
51 | env_(env) {} |
52 | |
53 | Status ReadLocked(tstring* key, tstring* value, bool* produced, |
54 | bool* at_end) override { |
55 | *key = current_work(); |
56 | TF_RETURN_IF_ERROR(ReadEntireFile(env_, *key, value)); |
57 | *produced = true; |
58 | *at_end = true; |
59 | return OkStatus(); |
60 | } |
61 | |
62 | // Stores state in a ReaderBaseState proto, since WholeFileReader has |
63 | // no additional state beyond ReaderBase. |
64 | Status SerializeStateLocked(tstring* state) override { |
65 | ReaderBaseState base_state; |
66 | SaveBaseState(&base_state); |
67 | SerializeToTString(base_state, state); |
68 | return OkStatus(); |
69 | } |
70 | |
71 | Status RestoreStateLocked(const tstring& state) override { |
72 | ReaderBaseState base_state; |
73 | if (!ParseProtoUnlimited(&base_state, state)) { |
74 | return errors::InvalidArgument("Could not parse state for " , name(), ": " , |
75 | absl::CEscape(state)); |
76 | } |
77 | TF_RETURN_IF_ERROR(RestoreBaseState(base_state)); |
78 | return OkStatus(); |
79 | } |
80 | |
81 | private: |
82 | Env* env_; |
83 | }; |
84 | |
85 | class WholeFileReaderOp : public ReaderOpKernel { |
86 | public: |
87 | explicit WholeFileReaderOp(OpKernelConstruction* context) |
88 | : ReaderOpKernel(context) { |
89 | Env* env = context->env(); |
90 | SetReaderFactory( |
91 | [this, env]() { return new WholeFileReader(env, name()); }); |
92 | } |
93 | }; |
94 | |
95 | REGISTER_KERNEL_BUILDER(Name("WholeFileReader" ).Device(DEVICE_CPU), |
96 | WholeFileReaderOp); |
97 | REGISTER_KERNEL_BUILDER(Name("WholeFileReaderV2" ).Device(DEVICE_CPU), |
98 | WholeFileReaderOp); |
99 | |
100 | class ReadFileOp : public OpKernel { |
101 | public: |
102 | using OpKernel::OpKernel; |
103 | void Compute(OpKernelContext* context) override { |
104 | const Tensor* input; |
105 | OP_REQUIRES_OK(context, context->input("filename" , &input)); |
106 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(input->shape()), |
107 | errors::InvalidArgument( |
108 | "Input filename tensor must be scalar, but had shape: " , |
109 | input->shape().DebugString())); |
110 | |
111 | Tensor* output = nullptr; |
112 | OP_REQUIRES_OK(context, context->allocate_output("contents" , |
113 | TensorShape({}), &output)); |
114 | OP_REQUIRES_OK(context, |
115 | ReadEntireFile(context->env(), input->scalar<tstring>()(), |
116 | &output->scalar<tstring>()())); |
117 | } |
118 | }; |
119 | |
120 | REGISTER_KERNEL_BUILDER(Name("ReadFile" ).Device(DEVICE_CPU), ReadFileOp); |
121 | |
122 | class WriteFileOp : public OpKernel { |
123 | public: |
124 | using OpKernel::OpKernel; |
125 | void Compute(OpKernelContext* context) override { |
126 | const Tensor* filename_input; |
127 | const Tensor* contents_input; |
128 | OP_REQUIRES_OK(context, context->input("filename" , &filename_input)); |
129 | OP_REQUIRES_OK(context, context->input("contents" , &contents_input)); |
130 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(filename_input->shape()), |
131 | errors::InvalidArgument( |
132 | "Input filename tensor must be scalar, but had shape: " , |
133 | filename_input->shape().DebugString())); |
134 | OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_input->shape()), |
135 | errors::InvalidArgument( |
136 | "Contents tensor must be scalar, but had shape: " , |
137 | contents_input->shape().DebugString())); |
138 | const string& filename = filename_input->scalar<tstring>()(); |
139 | const string dir(io::Dirname(filename)); |
140 | if (!context->env()->FileExists(dir).ok()) { |
141 | OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir)); |
142 | } |
143 | OP_REQUIRES_OK(context, |
144 | WriteStringToFile(context->env(), filename, |
145 | contents_input->scalar<tstring>()())); |
146 | } |
147 | }; |
148 | |
149 | REGISTER_KERNEL_BUILDER(Name("WriteFile" ).Device(DEVICE_CPU), WriteFileOp); |
150 | } // namespace tensorflow |
151 | |