/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <condition_variable>
#include <cstddef>
#include <deque>
#include <mutex>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

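// A thread-safe FIFO buffer of tensor tuples, shared between the Stage* ops
// below through the ResourceMgr. Inserters block while adding a tuple would
// exceed the configured capacity or memory limit; removers block while the
// buffer is empty.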
class Buffer : public ResourceBase {
 public:
  using Tuple = std::vector<Tensor>;

  explicit Buffer(std::size_t capacity, std::size_t memory_limit)
      : capacity_(capacity), memory_limit_(memory_limit), current_bytes_(0) {}

  // The Buffer takes ownership of the Tuple.
  Status Put(Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    std::size_t tuple_bytes = GetTupleBytes(*tuple);

    // Sanity check so that we don't block forever below
    if (memory_limit_ > 0 && tuple_bytes > memory_limit_) {
      return errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", tuple_bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'.");
    }

    // If the buffer is bounded, wait until elements have been removed
    if (IsBounded()) {
      full_cond_var_.wait(lock, [tuple_bytes, this]() {
        // If there's a memory limit, check if there's space for insertion
        bool memory_limit_valid =
            memory_limit_ > 0 ? !WouldExceedMemoryLimit(tuple_bytes) : true;
        // If we're configured for capacity, check if there's space for
        // insertion
        bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true;

        // Stop waiting once both conditions are satisfied
        return capacity_valid && memory_limit_valid;
      });
    }

    // Update bytes in the Staging Area
    current_bytes_ += tuple_bytes;

    // Store tuple
    buf_.push_back(std::move(*tuple));

    lock.unlock();
    // Notify all removers. Removers may be peeking at a specific element or
    // waiting for the element at the front of the deque. As we don't know the
    // appropriate one to wake up, we wake them all.
    non_empty_cond_var_.notify_all();

    return OkStatus();
  }

  // Get the tuple at the front of the buffer
  void Get(Tuple* tuple) {  // TODO(zhifengc): Support cancellation.
    std::unique_lock<std::mutex> lock(mu_);

    // Wait for data if the buffer is empty
    non_empty_cond_var_.wait(lock, [this]() { return !buf_.empty(); });

    // Move data into the output tuple
    *tuple = std::move(buf_.front());
    buf_.pop_front();

    // Update bytes in the Staging Area
    current_bytes_ -= GetTupleBytes(*tuple);

    notify_inserters_if_bounded(&lock);
  }

  // Return a copy of the tuple at the given index
  Status Peek(std::size_t index, Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    // Wait if the requested index is not yet available
    non_empty_cond_var_.wait(
        lock, [index, this]() { return index < this->buf_.size(); });

    // Place tensors in the output tuple
    for (const auto& tensor : buf_[index]) {
      tuple->push_back(tensor);
    }

    return OkStatus();
  }

  // Number of tuples currently in the buffer
  size_t Size() {
    std::unique_lock<std::mutex> lock(mu_);
    return buf_.size();
  }

  void Clear() {
    std::unique_lock<std::mutex> lock(mu_);
    buf_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded(&lock);
  }

  string DebugString() const override {
    std::unique_lock<std::mutex> lock(mu_);
    return strings::StrCat("Staging size: ", buf_.size());
  }

 private:
  // If the buffer is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded(std::unique_lock<std::mutex>* lock) {
    if (IsBounded()) {
      lock->unlock();
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_cond_var_.notify_all();
    }
  }

  // Is there a limit on the number of elements or a memory limit
  // configured on this buffer?
  bool IsBounded() const { return capacity_ > 0 || memory_limit_ > 0; }

  bool IsCapacityFull() const { return buf_.size() >= capacity_; }

  bool WouldExceedMemoryLimit(std::size_t bytes) const {
    return bytes + current_bytes_ > memory_limit_;
  }

  // Total size in bytes of all tensors in the tuple. Note the std::size_t
  // initial value: a plain 0 would give std::accumulate an int accumulator
  // and silently truncate large byte counts.
  std::size_t GetTupleBytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(), std::size_t{0},
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  std::size_t capacity_;
  std::size_t memory_limit_;
  std::size_t current_bytes_;
  mutable std::mutex mu_;
  std::condition_variable non_empty_cond_var_;
  std::condition_variable full_cond_var_;
  std::deque<Tuple> buf_;
};

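// Looks up the Buffer resource for this node in the ResourceMgr, creating it
// on first use from the node's "capacity" and "memory_limit" attrs.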
Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](Buffer** ret) -> Status {
    int64_t capacity;
    int64_t memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new Buffer(capacity, memory_limit);
    return OkStatus();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<Buffer>(cinfo.container(), cinfo.name(),
                                                buf, create_fn));
  return OkStatus();
}

}  // namespace

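// Inserts one tuple of input tensors into the Staging Area, blocking if the
// buffer is bounded and currently full.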
class StageOp : public OpKernel {
 public:
  explicit StageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;
    tuple.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      tuple.push_back(ctx->input(i));
    }
    OP_REQUIRES_OK(ctx, buf->Put(&tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp);
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_DEFAULT), StageOp);

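// Removes the tuple at the front of the Staging Area and emits its tensors
// as outputs, blocking while the buffer is empty.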
class UnstageOp : public OpKernel {
 public:
  explicit UnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    buf->Get(&tuple);

    OP_REQUIRES(
        ctx, tuple.size() == (size_t)ctx->num_outputs(),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp);
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_DEFAULT), UnstageOp);

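// Emits a copy of the tuple at the given index without removing it, blocking
// until that index becomes available.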
class StagePeekOp : public OpKernel {
 public:
  explicit StagePeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(0).shape()),
                errors::InvalidArgument("index must be scalar"));
    const int raw_index = ctx->input(0).scalar<int>()();
    // Reject negative indices here: casting one to std::size_t would yield a
    // huge value, and Peek would then block forever.
    OP_REQUIRES(ctx, raw_index >= 0,
                errors::InvalidArgument("index must be non-negative, got ",
                                        raw_index));
    const std::size_t index = raw_index;

    OP_REQUIRES_OK(ctx, buf->Peek(index, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == (size_t)ctx->num_outputs(),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), StagePeekOp);
REGISTER_KERNEL_BUILDER(
    Name("StagePeek").HostMemory("index").Device(DEVICE_DEFAULT), StagePeekOp);

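// Emits the current number of tuples in the Staging Area as a scalar int32.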
class StageSizeOp : public OpKernel {
 public:
  explicit StageSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Unlike Get and Peek, Size never blocks, so cancellation
  // does not need to be handled here.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    // Allocate the scalar size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(buf->Size());
  }
};

REGISTER_KERNEL_BUILDER(Name("StageSize").Device(DEVICE_CPU), StageSizeOp);
REGISTER_KERNEL_BUILDER(
    Name("StageSize").HostMemory("size").Device(DEVICE_DEFAULT), StageSizeOp);

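// Removes all tuples from the Staging Area and resets its byte count, waking
// any blocked inserters.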
class StageClearOp : public OpKernel {
 public:
  explicit StageClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Unlike Get and Peek, Clear never blocks, so cancellation
  // does not need to be handled here.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    buf->Clear();
  }
};

REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_CPU), StageClearOp);
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_DEFAULT),
                        StageClearOp);

}  // namespace tensorflow