1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/io_ops.cc.
17
18#include "tensorflow/core/framework/op_kernel.h"
19#include "tensorflow/core/framework/queue_interface.h"
20#include "tensorflow/core/framework/reader_interface.h"
21#include "tensorflow/core/framework/tensor_shape.h"
22#include "tensorflow/core/kernels/ops_util.h"
23#include "tensorflow/core/lib/core/threadpool.h"
24#include "tensorflow/core/lib/strings/strcat.h"
25
26namespace tensorflow {
27
28class ReaderVerbSyncOpKernel : public OpKernel {
29 public:
30 using OpKernel::OpKernel;
31
32 void Compute(OpKernelContext* context) override {
33 ReaderInterface* reader;
34 OP_REQUIRES_OK(context,
35 GetResourceFromContext(context, "reader_handle", &reader));
36 ComputeWithReader(context, reader);
37 reader->Unref();
38 }
39
40 protected:
41 virtual void ComputeWithReader(OpKernelContext* context,
42 ReaderInterface* reader) = 0;
43};
44
45class ReaderVerbAsyncOpKernel : public AsyncOpKernel {
46 public:
47 using AsyncOpKernel::AsyncOpKernel;
48
49 explicit ReaderVerbAsyncOpKernel(OpKernelConstruction* context)
50 : AsyncOpKernel(context),
51 thread_pool_(new thread::ThreadPool(
52 context->env(), ThreadOptions(),
53 strings::StrCat("reader_thread_", SanitizeThreadSuffix(name())),
54 1 /* num_threads */, false /* low_latency_hint */)) {}
55
56 void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
57 ReaderInterface* reader;
58 OP_REQUIRES_OK_ASYNC(
59 context, GetResourceFromContext(context, "reader_handle", &reader),
60 done);
61 thread_pool_->Schedule([this, context, reader, done]() {
62 ComputeWithReader(context, reader);
63 reader->Unref();
64 done();
65 });
66 }
67
68 protected:
69 virtual void ComputeWithReader(OpKernelContext* context,
70 ReaderInterface* reader) = 0;
71
72 private:
73 std::unique_ptr<thread::ThreadPool> thread_pool_;
74};
75
76class ReaderReadOp : public ReaderVerbAsyncOpKernel {
77 public:
78 using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel;
79
80 void ComputeWithReader(OpKernelContext* context,
81 ReaderInterface* reader) override {
82 QueueInterface* queue;
83 OP_REQUIRES_OK(context,
84 GetResourceFromContext(context, "queue_handle", &queue));
85 core::ScopedUnref unref_me(queue);
86 Tensor* key = nullptr;
87 OP_REQUIRES_OK(context,
88 context->allocate_output("key", TensorShape({}), &key));
89 Tensor* value = nullptr;
90 OP_REQUIRES_OK(context,
91 context->allocate_output("value", TensorShape({}), &value));
92
93 auto key_scalar = key->scalar<tstring>();
94 auto value_scalar = value->scalar<tstring>();
95 tstring key_out, val_out;
96 reader->Read(queue, &key_out, &val_out, context);
97 key_scalar() = key_out;
98 value_scalar() = val_out;
99 }
100};
101
102REGISTER_KERNEL_BUILDER(Name("ReaderRead").Device(DEVICE_CPU), ReaderReadOp);
103REGISTER_KERNEL_BUILDER(Name("ReaderReadV2").Device(DEVICE_CPU), ReaderReadOp);
104
105class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel {
106 public:
107 using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel;
108
109 void ComputeWithReader(OpKernelContext* context,
110 ReaderInterface* reader) override {
111 QueueInterface* queue;
112
113 const Tensor* num_records_tensor;
114 OP_REQUIRES_OK(context, context->input("num_records", &num_records_tensor));
115 int64_t num_records = num_records_tensor->scalar<int64_t>()();
116
117 OP_REQUIRES_OK(context,
118 GetResourceFromContext(context, "queue_handle", &queue));
119 core::ScopedUnref unref_me(queue);
120
121 std::vector<tstring> keys_vec;
122 keys_vec.reserve(num_records);
123 std::vector<tstring> values_vec;
124 values_vec.reserve(num_records);
125
126 int64_t num_actually_read =
127 reader->ReadUpTo(num_records, queue, &keys_vec, &values_vec, context);
128
129 OP_REQUIRES(context, num_actually_read == keys_vec.size(),
130 errors::InvalidArgument("num_actually_read != len(keys_vec"));
131
132 OP_REQUIRES(context, num_actually_read == values_vec.size(),
133 errors::InvalidArgument("num_actually_read != len(values_vec"));
134
135 Tensor* keys = nullptr;
136 OP_REQUIRES_OK(context,
137 context->allocate_output(
138 "keys", TensorShape({num_actually_read}), &keys));
139
140 Tensor* values = nullptr;
141 OP_REQUIRES_OK(context,
142 context->allocate_output(
143 "values", TensorShape({num_actually_read}), &values));
144
145 auto keys_t = keys->vec<tstring>();
146 auto values_t = values->vec<tstring>();
147 for (int i = 0; i < num_actually_read; ++i) {
148 keys_t(i) = std::move(keys_vec[i]);
149 values_t(i) = std::move(values_vec[i]);
150 }
151 }
152};
153
154REGISTER_KERNEL_BUILDER(Name("ReaderReadUpTo").Device(DEVICE_CPU),
155 ReaderReadUpToOp);
156REGISTER_KERNEL_BUILDER(Name("ReaderReadUpToV2").Device(DEVICE_CPU),
157 ReaderReadUpToOp);
158
159class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel {
160 public:
161 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
162
163 void ComputeWithReader(OpKernelContext* context,
164 ReaderInterface* reader) override {
165 Tensor* output = nullptr;
166 OP_REQUIRES_OK(context, context->allocate_output("records_produced",
167 TensorShape({}), &output));
168 output->scalar<int64_t>()() = reader->NumRecordsProduced();
169 }
170};
171
172REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced").Device(DEVICE_CPU),
173 ReaderNumRecordsProducedOp);
174REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProducedV2").Device(DEVICE_CPU),
175 ReaderNumRecordsProducedOp);
176
177class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel {
178 public:
179 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
180
181 void ComputeWithReader(OpKernelContext* context,
182 ReaderInterface* reader) override {
183 Tensor* output = nullptr;
184 OP_REQUIRES_OK(context, context->allocate_output("units_completed",
185 TensorShape({}), &output));
186 output->scalar<int64_t>()() = reader->NumWorkUnitsCompleted();
187 }
188};
189
190REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted").Device(DEVICE_CPU),
191 ReaderNumWorkUnitsCompletedOp);
192REGISTER_KERNEL_BUILDER(
193 Name("ReaderNumWorkUnitsCompletedV2").Device(DEVICE_CPU),
194 ReaderNumWorkUnitsCompletedOp);
195
196class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel {
197 public:
198 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
199
200 void ComputeWithReader(OpKernelContext* context,
201 ReaderInterface* reader) override {
202 Tensor* output = nullptr;
203 OP_REQUIRES_OK(context,
204 context->allocate_output("state", TensorShape({}), &output));
205 OP_REQUIRES_OK(context,
206 reader->SerializeState(&output->scalar<tstring>()()));
207 }
208};
209
210REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState").Device(DEVICE_CPU),
211 ReaderSerializeStateOp);
212REGISTER_KERNEL_BUILDER(Name("ReaderSerializeStateV2").Device(DEVICE_CPU),
213 ReaderSerializeStateOp);
214
215class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel {
216 public:
217 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
218
219 void ComputeWithReader(OpKernelContext* context,
220 ReaderInterface* reader) override {
221 const Tensor* tensor;
222 OP_REQUIRES_OK(context, context->input("state", &tensor));
223 OP_REQUIRES(
224 context, TensorShapeUtils::IsScalar(tensor->shape()),
225 errors::InvalidArgument("Reader state must be scalar, but had shape: ",
226 tensor->shape().DebugString()));
227 OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar<tstring>()()));
228 }
229};
230
231REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState").Device(DEVICE_CPU),
232 ReaderRestoreStateOp);
233REGISTER_KERNEL_BUILDER(Name("ReaderRestoreStateV2").Device(DEVICE_CPU),
234 ReaderRestoreStateOp);
235
236class ReaderResetOp : public ReaderVerbSyncOpKernel {
237 public:
238 using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel;
239
240 void ComputeWithReader(OpKernelContext* context,
241 ReaderInterface* reader) override {
242 OP_REQUIRES_OK(context, reader->Reset());
243 }
244};
245
246REGISTER_KERNEL_BUILDER(Name("ReaderReset").Device(DEVICE_CPU), ReaderResetOp);
247REGISTER_KERNEL_BUILDER(Name("ReaderResetV2").Device(DEVICE_CPU),
248 ReaderResetOp);
249
250} // namespace tensorflow
251