1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/io_ops.cc.
17
18#include <memory>
19
20#include "absl/strings/escaping.h"
21#include "tensorflow/core/framework/op_kernel.h"
22#include "tensorflow/core/framework/reader_base.h"
23#include "tensorflow/core/framework/reader_base.pb.h"
24#include "tensorflow/core/framework/reader_op_kernel.h"
25#include "tensorflow/core/framework/tensor_shape.h"
26#include "tensorflow/core/lib/core/errors.h"
27#include "tensorflow/core/lib/io/buffered_inputstream.h"
28#include "tensorflow/core/lib/io/path.h"
29#include "tensorflow/core/lib/io/random_inputstream.h"
30#include "tensorflow/core/lib/strings/str_util.h"
31#include "tensorflow/core/lib/strings/strcat.h"
32#include "tensorflow/core/platform/env.h"
33#include "tensorflow/core/platform/protobuf.h"
34
35namespace tensorflow {
36
37template <typename T>
38static Status ReadEntireFile(Env* env, const string& filename, T* contents) {
39 std::unique_ptr<RandomAccessFile> file;
40 TF_RETURN_IF_ERROR(env->NewRandomAccessFile(filename, &file));
41 io::RandomAccessInputStream input_stream(file.get());
42 io::BufferedInputStream in(&input_stream, 1 << 20);
43 TF_RETURN_IF_ERROR(in.ReadAll(contents));
44 return OkStatus();
45}
46
47class WholeFileReader : public ReaderBase {
48 public:
49 WholeFileReader(Env* env, const string& node_name)
50 : ReaderBase(strings::StrCat("WholeFileReader '", node_name, "'")),
51 env_(env) {}
52
53 Status ReadLocked(tstring* key, tstring* value, bool* produced,
54 bool* at_end) override {
55 *key = current_work();
56 TF_RETURN_IF_ERROR(ReadEntireFile(env_, *key, value));
57 *produced = true;
58 *at_end = true;
59 return OkStatus();
60 }
61
62 // Stores state in a ReaderBaseState proto, since WholeFileReader has
63 // no additional state beyond ReaderBase.
64 Status SerializeStateLocked(tstring* state) override {
65 ReaderBaseState base_state;
66 SaveBaseState(&base_state);
67 SerializeToTString(base_state, state);
68 return OkStatus();
69 }
70
71 Status RestoreStateLocked(const tstring& state) override {
72 ReaderBaseState base_state;
73 if (!ParseProtoUnlimited(&base_state, state)) {
74 return errors::InvalidArgument("Could not parse state for ", name(), ": ",
75 absl::CEscape(state));
76 }
77 TF_RETURN_IF_ERROR(RestoreBaseState(base_state));
78 return OkStatus();
79 }
80
81 private:
82 Env* env_;
83};
84
85class WholeFileReaderOp : public ReaderOpKernel {
86 public:
87 explicit WholeFileReaderOp(OpKernelConstruction* context)
88 : ReaderOpKernel(context) {
89 Env* env = context->env();
90 SetReaderFactory(
91 [this, env]() { return new WholeFileReader(env, name()); });
92 }
93};
94
// One CPU kernel implementation, WholeFileReaderOp, serves both the
// "WholeFileReader" and "WholeFileReaderV2" op names.
REGISTER_KERNEL_BUILDER(Name("WholeFileReader").Device(DEVICE_CPU),
                        WholeFileReaderOp);
REGISTER_KERNEL_BUILDER(Name("WholeFileReaderV2").Device(DEVICE_CPU),
                        WholeFileReaderOp);
99
100class ReadFileOp : public OpKernel {
101 public:
102 using OpKernel::OpKernel;
103 void Compute(OpKernelContext* context) override {
104 const Tensor* input;
105 OP_REQUIRES_OK(context, context->input("filename", &input));
106 OP_REQUIRES(context, TensorShapeUtils::IsScalar(input->shape()),
107 errors::InvalidArgument(
108 "Input filename tensor must be scalar, but had shape: ",
109 input->shape().DebugString()));
110
111 Tensor* output = nullptr;
112 OP_REQUIRES_OK(context, context->allocate_output("contents",
113 TensorShape({}), &output));
114 OP_REQUIRES_OK(context,
115 ReadEntireFile(context->env(), input->scalar<tstring>()(),
116 &output->scalar<tstring>()()));
117 }
118};
119
// ReadFileOp is the CPU kernel for the "ReadFile" op.
REGISTER_KERNEL_BUILDER(Name("ReadFile").Device(DEVICE_CPU), ReadFileOp);
121
122class WriteFileOp : public OpKernel {
123 public:
124 using OpKernel::OpKernel;
125 void Compute(OpKernelContext* context) override {
126 const Tensor* filename_input;
127 const Tensor* contents_input;
128 OP_REQUIRES_OK(context, context->input("filename", &filename_input));
129 OP_REQUIRES_OK(context, context->input("contents", &contents_input));
130 OP_REQUIRES(context, TensorShapeUtils::IsScalar(filename_input->shape()),
131 errors::InvalidArgument(
132 "Input filename tensor must be scalar, but had shape: ",
133 filename_input->shape().DebugString()));
134 OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_input->shape()),
135 errors::InvalidArgument(
136 "Contents tensor must be scalar, but had shape: ",
137 contents_input->shape().DebugString()));
138 const string& filename = filename_input->scalar<tstring>()();
139 const string dir(io::Dirname(filename));
140 if (!context->env()->FileExists(dir).ok()) {
141 OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir));
142 }
143 OP_REQUIRES_OK(context,
144 WriteStringToFile(context->env(), filename,
145 contents_input->scalar<tstring>()()));
146 }
147};
148
// WriteFileOp is the CPU kernel for the "WriteFile" op.
REGISTER_KERNEL_BUILDER(Name("WriteFile").Device(DEVICE_CPU), WriteFileOp);
150} // namespace tensorflow
151