/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <condition_variable>
#include <cstddef>
#include <deque>
#include <mutex>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

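// In-process staging area shared between the Stage* kernels below via the
// ResourceMgr. Tuples of tensors are kept in FIFO order; a capacity of 0
// means "unlimited elements" and a memory_limit of 0 means "unlimited bytes".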
class Buffer : public ResourceBase {
 public:
  using Tuple = std::vector<Tensor>;

  explicit Buffer(std::size_t capacity, std::size_t memory_limit)
      : capacity_(capacity), memory_limit_(memory_limit), current_bytes_(0) {}

  // The Buffer takes ownership of the Tuple.
  Status Put(Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    std::size_t tuple_bytes = GetTupleBytes(*tuple);

    // Sanity check so that we don't block forever below
    if (memory_limit_ > 0 && tuple_bytes > memory_limit_) {
      return errors::ResourceExhausted(
          "Attempted to insert tensors with combined size of '", tuple_bytes,
          "' bytes into Staging Area with a memory limit of '", memory_limit_,
          "'.");
    }

    // If the buffer is bounded, wait until elements have been removed
    if (IsBounded()) {
      full_cond_var_.wait(lock, [tuple_bytes, this]() {
        // If there's a memory limit, check if there's space for insertion
        bool memory_limit_valid =
            memory_limit_ > 0 ? !WouldExceedMemoryLimit(tuple_bytes) : true;
        // If we're configured for capacity, check if there's space for insertion
        bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true;

        // Stop waiting once both conditions are satisfied
        return capacity_valid && memory_limit_valid;
      });
    }

    // Update bytes in the Staging Area
    current_bytes_ += tuple_bytes;

    // Store tuple
    buf_.push_back(std::move(*tuple));

    lock.unlock();
    // Notify all removers. A remover may be peeking at a specific element
    // or waiting for the element at the front of the deque. As we don't
    // know which one to wake up, we wake them all.
    non_empty_cond_var_.notify_all();

    return OkStatus();
  }

  // Get tuple at front of the buffer
  void Get(Tuple* tuple) {  // TODO(zhifengc): Support cancellation.
    std::unique_lock<std::mutex> lock(mu_);

    // Wait for data if the buffer is empty
    non_empty_cond_var_.wait(lock, [this]() { return !buf_.empty(); });

    // Move data into the output tuple
    *tuple = std::move(buf_.front());
    buf_.pop_front();

    // Update bytes in the Staging Area
    current_bytes_ -= GetTupleBytes(*tuple);

    notify_inserters_if_bounded(&lock);
  }

  // Return tuple at index
  Status Peek(std::size_t index, Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    // Wait if the requested index is not available
    non_empty_cond_var_.wait(
        lock, [index, this]() { return index < this->buf_.size(); });

    // Place tensors in the output tuple
    for (const auto& tensor : buf_[index]) {
      tuple->push_back(tensor);
    }

    return OkStatus();
  }

  // Buffer size
  size_t Size() {
    std::unique_lock<std::mutex> lock(mu_);
    return buf_.size();
  }

  void Clear() {
    std::unique_lock<std::mutex> lock(mu_);
    buf_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded(&lock);
  }

  string DebugString() const override {
    std::unique_lock<std::mutex> lock(mu_);
    return strings::StrCat("Staging size: ", buf_.size());
  }

 private:
  // If the buffer is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded(std::unique_lock<std::mutex>* lock) {
    if (IsBounded()) {
      lock->unlock();
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_cond_var_.notify_all();
    }
  }

  // Is there a limit on the number of elements or a memory limit
  // configured for this buffer?
  bool IsBounded() const { return capacity_ > 0 || memory_limit_ > 0; }

  bool IsCapacityFull() const { return buf_.size() >= capacity_; }

  bool WouldExceedMemoryLimit(std::size_t bytes) const {
    return bytes + current_bytes_ > memory_limit_;
  }

  std::size_t GetTupleBytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(), std::size_t(0),
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

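  // capacity_ and memory_limit_ are fixed at construction; a value of 0
  // disables the corresponding bound. buf_, current_bytes_ and the two
  // condition variables are used together under mu_.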
  std::size_t capacity_;
  std::size_t memory_limit_;
  std::size_t current_bytes_;
  mutable std::mutex mu_;
  std::condition_variable non_empty_cond_var_;
  std::condition_variable full_cond_var_;
  std::deque<Tuple> buf_;
};

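// Looks up (or creates) the Buffer resource for this node in the ResourceMgr,
// reading the "capacity" and "memory_limit" node attributes on creation.
// On success the caller holds a reference to *buf and must Unref it.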
Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](Buffer** ret) -> Status {
    int64_t capacity;
    int64_t memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new Buffer(capacity, memory_limit);
    return OkStatus();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<Buffer>(cinfo.container(), cinfo.name(),
                                                buf, create_fn));
  return OkStatus();
}

}  // namespace

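// Stages all of its inputs as a single tuple in the staging buffer, blocking
// if the buffer is bounded and currently full.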
class StageOp : public OpKernel {
 public:
  explicit StageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;
    tuple.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      tuple.push_back(ctx->input(i));
    }
    OP_REQUIRES_OK(ctx, buf->Put(&tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp);
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_DEFAULT), StageOp);

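// Removes the tuple at the front of the staging buffer and emits its tensors
// as outputs, blocking until an element is available.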
class UnstageOp : public OpKernel {
 public:
  explicit UnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    buf->Get(&tuple);

    OP_REQUIRES(
        ctx, tuple.size() == (size_t)ctx->num_outputs(),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp);
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_DEFAULT), UnstageOp);

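// Outputs the tensors of the tuple at the given index without removing it,
// blocking until that index exists in the buffer.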
class StagePeekOp : public OpKernel {
 public:
  explicit StagePeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error. As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(ctx->input(0).shape()),
                errors::InvalidArgument("index must be scalar"));
    std::size_t index = ctx->input(0).scalar<int>()();

    OP_REQUIRES_OK(ctx, buf->Peek(index, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == (size_t)ctx->num_outputs(),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), StagePeekOp);
REGISTER_KERNEL_BUILDER(
    Name("StagePeek").HostMemory("index").Device(DEVICE_DEFAULT), StagePeekOp);

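// Outputs the current number of tuples in the staging buffer as a scalar
// int32.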
class StageSizeOp : public OpKernel {
 public:
  explicit StageSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // This op never blocks indefinitely; it only holds the buffer
  // lock briefly, so cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(buf->Size());
  }
};

REGISTER_KERNEL_BUILDER(Name("StageSize").Device(DEVICE_CPU), StageSizeOp);
REGISTER_KERNEL_BUILDER(
    Name("StageSize").HostMemory("size").Device(DEVICE_DEFAULT), StageSizeOp);

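// Removes all tuples from the staging buffer.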
class StageClearOp : public OpKernel {
 public:
  explicit StageClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // This op never blocks indefinitely; it only holds the buffer
  // lock briefly, so cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    buf->Clear();
  }
};

REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_CPU), StageClearOp);
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_DEFAULT),
                        StageClearOp);

}  // namespace tensorflow