1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/io_ops.cc. |
17 | |
18 | #include "tensorflow/core/framework/op_kernel.h" |
19 | #include "tensorflow/core/framework/queue_interface.h" |
20 | #include "tensorflow/core/framework/reader_interface.h" |
21 | #include "tensorflow/core/framework/tensor_shape.h" |
22 | #include "tensorflow/core/kernels/ops_util.h" |
23 | #include "tensorflow/core/lib/core/threadpool.h" |
24 | #include "tensorflow/core/lib/strings/strcat.h" |
25 | |
26 | namespace tensorflow { |
27 | |
28 | class ReaderVerbSyncOpKernel : public OpKernel { |
29 | public: |
30 | using OpKernel::OpKernel; |
31 | |
32 | void Compute(OpKernelContext* context) override { |
33 | ReaderInterface* reader; |
34 | OP_REQUIRES_OK(context, |
35 | GetResourceFromContext(context, "reader_handle" , &reader)); |
36 | ComputeWithReader(context, reader); |
37 | reader->Unref(); |
38 | } |
39 | |
40 | protected: |
41 | virtual void ComputeWithReader(OpKernelContext* context, |
42 | ReaderInterface* reader) = 0; |
43 | }; |
44 | |
45 | class ReaderVerbAsyncOpKernel : public AsyncOpKernel { |
46 | public: |
47 | using AsyncOpKernel::AsyncOpKernel; |
48 | |
49 | explicit ReaderVerbAsyncOpKernel(OpKernelConstruction* context) |
50 | : AsyncOpKernel(context), |
51 | thread_pool_(new thread::ThreadPool( |
52 | context->env(), ThreadOptions(), |
53 | strings::StrCat("reader_thread_" , SanitizeThreadSuffix(name())), |
54 | 1 /* num_threads */, false /* low_latency_hint */)) {} |
55 | |
56 | void ComputeAsync(OpKernelContext* context, DoneCallback done) override { |
57 | ReaderInterface* reader; |
58 | OP_REQUIRES_OK_ASYNC( |
59 | context, GetResourceFromContext(context, "reader_handle" , &reader), |
60 | done); |
61 | thread_pool_->Schedule([this, context, reader, done]() { |
62 | ComputeWithReader(context, reader); |
63 | reader->Unref(); |
64 | done(); |
65 | }); |
66 | } |
67 | |
68 | protected: |
69 | virtual void ComputeWithReader(OpKernelContext* context, |
70 | ReaderInterface* reader) = 0; |
71 | |
72 | private: |
73 | std::unique_ptr<thread::ThreadPool> thread_pool_; |
74 | }; |
75 | |
76 | class ReaderReadOp : public ReaderVerbAsyncOpKernel { |
77 | public: |
78 | using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel; |
79 | |
80 | void ComputeWithReader(OpKernelContext* context, |
81 | ReaderInterface* reader) override { |
82 | QueueInterface* queue; |
83 | OP_REQUIRES_OK(context, |
84 | GetResourceFromContext(context, "queue_handle" , &queue)); |
85 | core::ScopedUnref unref_me(queue); |
86 | Tensor* key = nullptr; |
87 | OP_REQUIRES_OK(context, |
88 | context->allocate_output("key" , TensorShape({}), &key)); |
89 | Tensor* value = nullptr; |
90 | OP_REQUIRES_OK(context, |
91 | context->allocate_output("value" , TensorShape({}), &value)); |
92 | |
93 | auto key_scalar = key->scalar<tstring>(); |
94 | auto value_scalar = value->scalar<tstring>(); |
95 | tstring key_out, val_out; |
96 | reader->Read(queue, &key_out, &val_out, context); |
97 | key_scalar() = key_out; |
98 | value_scalar() = val_out; |
99 | } |
100 | }; |
101 | |
102 | REGISTER_KERNEL_BUILDER(Name("ReaderRead" ).Device(DEVICE_CPU), ReaderReadOp); |
103 | REGISTER_KERNEL_BUILDER(Name("ReaderReadV2" ).Device(DEVICE_CPU), ReaderReadOp); |
104 | |
105 | class ReaderReadUpToOp : public ReaderVerbAsyncOpKernel { |
106 | public: |
107 | using ReaderVerbAsyncOpKernel::ReaderVerbAsyncOpKernel; |
108 | |
109 | void ComputeWithReader(OpKernelContext* context, |
110 | ReaderInterface* reader) override { |
111 | QueueInterface* queue; |
112 | |
113 | const Tensor* num_records_tensor; |
114 | OP_REQUIRES_OK(context, context->input("num_records" , &num_records_tensor)); |
115 | int64_t num_records = num_records_tensor->scalar<int64_t>()(); |
116 | |
117 | OP_REQUIRES_OK(context, |
118 | GetResourceFromContext(context, "queue_handle" , &queue)); |
119 | core::ScopedUnref unref_me(queue); |
120 | |
121 | std::vector<tstring> keys_vec; |
122 | keys_vec.reserve(num_records); |
123 | std::vector<tstring> values_vec; |
124 | values_vec.reserve(num_records); |
125 | |
126 | int64_t num_actually_read = |
127 | reader->ReadUpTo(num_records, queue, &keys_vec, &values_vec, context); |
128 | |
129 | OP_REQUIRES(context, num_actually_read == keys_vec.size(), |
130 | errors::InvalidArgument("num_actually_read != len(keys_vec" )); |
131 | |
132 | OP_REQUIRES(context, num_actually_read == values_vec.size(), |
133 | errors::InvalidArgument("num_actually_read != len(values_vec" )); |
134 | |
135 | Tensor* keys = nullptr; |
136 | OP_REQUIRES_OK(context, |
137 | context->allocate_output( |
138 | "keys" , TensorShape({num_actually_read}), &keys)); |
139 | |
140 | Tensor* values = nullptr; |
141 | OP_REQUIRES_OK(context, |
142 | context->allocate_output( |
143 | "values" , TensorShape({num_actually_read}), &values)); |
144 | |
145 | auto keys_t = keys->vec<tstring>(); |
146 | auto values_t = values->vec<tstring>(); |
147 | for (int i = 0; i < num_actually_read; ++i) { |
148 | keys_t(i) = std::move(keys_vec[i]); |
149 | values_t(i) = std::move(values_vec[i]); |
150 | } |
151 | } |
152 | }; |
153 | |
154 | REGISTER_KERNEL_BUILDER(Name("ReaderReadUpTo" ).Device(DEVICE_CPU), |
155 | ReaderReadUpToOp); |
156 | REGISTER_KERNEL_BUILDER(Name("ReaderReadUpToV2" ).Device(DEVICE_CPU), |
157 | ReaderReadUpToOp); |
158 | |
159 | class ReaderNumRecordsProducedOp : public ReaderVerbSyncOpKernel { |
160 | public: |
161 | using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; |
162 | |
163 | void ComputeWithReader(OpKernelContext* context, |
164 | ReaderInterface* reader) override { |
165 | Tensor* output = nullptr; |
166 | OP_REQUIRES_OK(context, context->allocate_output("records_produced" , |
167 | TensorShape({}), &output)); |
168 | output->scalar<int64_t>()() = reader->NumRecordsProduced(); |
169 | } |
170 | }; |
171 | |
172 | REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProduced" ).Device(DEVICE_CPU), |
173 | ReaderNumRecordsProducedOp); |
174 | REGISTER_KERNEL_BUILDER(Name("ReaderNumRecordsProducedV2" ).Device(DEVICE_CPU), |
175 | ReaderNumRecordsProducedOp); |
176 | |
177 | class ReaderNumWorkUnitsCompletedOp : public ReaderVerbSyncOpKernel { |
178 | public: |
179 | using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; |
180 | |
181 | void ComputeWithReader(OpKernelContext* context, |
182 | ReaderInterface* reader) override { |
183 | Tensor* output = nullptr; |
184 | OP_REQUIRES_OK(context, context->allocate_output("units_completed" , |
185 | TensorShape({}), &output)); |
186 | output->scalar<int64_t>()() = reader->NumWorkUnitsCompleted(); |
187 | } |
188 | }; |
189 | |
190 | REGISTER_KERNEL_BUILDER(Name("ReaderNumWorkUnitsCompleted" ).Device(DEVICE_CPU), |
191 | ReaderNumWorkUnitsCompletedOp); |
192 | REGISTER_KERNEL_BUILDER( |
193 | Name("ReaderNumWorkUnitsCompletedV2" ).Device(DEVICE_CPU), |
194 | ReaderNumWorkUnitsCompletedOp); |
195 | |
196 | class ReaderSerializeStateOp : public ReaderVerbSyncOpKernel { |
197 | public: |
198 | using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; |
199 | |
200 | void ComputeWithReader(OpKernelContext* context, |
201 | ReaderInterface* reader) override { |
202 | Tensor* output = nullptr; |
203 | OP_REQUIRES_OK(context, |
204 | context->allocate_output("state" , TensorShape({}), &output)); |
205 | OP_REQUIRES_OK(context, |
206 | reader->SerializeState(&output->scalar<tstring>()())); |
207 | } |
208 | }; |
209 | |
210 | REGISTER_KERNEL_BUILDER(Name("ReaderSerializeState" ).Device(DEVICE_CPU), |
211 | ReaderSerializeStateOp); |
212 | REGISTER_KERNEL_BUILDER(Name("ReaderSerializeStateV2" ).Device(DEVICE_CPU), |
213 | ReaderSerializeStateOp); |
214 | |
215 | class ReaderRestoreStateOp : public ReaderVerbSyncOpKernel { |
216 | public: |
217 | using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; |
218 | |
219 | void ComputeWithReader(OpKernelContext* context, |
220 | ReaderInterface* reader) override { |
221 | const Tensor* tensor; |
222 | OP_REQUIRES_OK(context, context->input("state" , &tensor)); |
223 | OP_REQUIRES( |
224 | context, TensorShapeUtils::IsScalar(tensor->shape()), |
225 | errors::InvalidArgument("Reader state must be scalar, but had shape: " , |
226 | tensor->shape().DebugString())); |
227 | OP_REQUIRES_OK(context, reader->RestoreState(tensor->scalar<tstring>()())); |
228 | } |
229 | }; |
230 | |
231 | REGISTER_KERNEL_BUILDER(Name("ReaderRestoreState" ).Device(DEVICE_CPU), |
232 | ReaderRestoreStateOp); |
233 | REGISTER_KERNEL_BUILDER(Name("ReaderRestoreStateV2" ).Device(DEVICE_CPU), |
234 | ReaderRestoreStateOp); |
235 | |
236 | class ReaderResetOp : public ReaderVerbSyncOpKernel { |
237 | public: |
238 | using ReaderVerbSyncOpKernel::ReaderVerbSyncOpKernel; |
239 | |
240 | void ComputeWithReader(OpKernelContext* context, |
241 | ReaderInterface* reader) override { |
242 | OP_REQUIRES_OK(context, reader->Reset()); |
243 | } |
244 | }; |
245 | |
246 | REGISTER_KERNEL_BUILDER(Name("ReaderReset" ).Device(DEVICE_CPU), ReaderResetOp); |
247 | REGISTER_KERNEL_BUILDER(Name("ReaderResetV2" ).Device(DEVICE_CPU), |
248 | ReaderResetOp); |
249 | |
250 | } // namespace tensorflow |
251 | |