/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/common_runtime/scoped_allocator.h"
#include "tensorflow/core/common_runtime/scoped_allocator_mgr.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

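// Allocates one large "backing" tensor as output 0 and registers a
// ScopedAllocator for it with the device's ScopedAllocatorMgr, so that a set
// of downstream ops can have their outputs carved out of that buffer as
// contiguous, aligned slices.  The "shapes" attribute describes the slices,
// and "expected_call_count" is the number of allocations the ScopedAllocator
// should expect to serve.  This op is typically inserted by a graph rewrite
// rather than written by hand.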
class ScopedAllocatorOp : public OpKernel {
 public:
  explicit ScopedAllocatorOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
    OP_REQUIRES_OK(context, context->GetAttr("shapes", &shapes_));
    OP_REQUIRES_OK(context, context->GetAttr("sa_name", &name_));
    OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
    OP_REQUIRES_OK(context, context->GetAttr("expected_call_count",
                                             &expected_call_count_));
    device_ = context->device();
    // Precalculate the size of the backing tensor and the offsets of
    // the subtensors to be allocated from it, taking into account
    // alignment considerations.
    ScopedAllocatorMgr::PopulateFields(id_, shapes_, dtype_, &fields_);
    size_t num_bytes = fields_.back().offset + fields_.back().bytes_allocated;
    num_elements_ = num_bytes / DataTypeSize(dtype_);
    OP_REQUIRES(context, num_bytes % DataTypeSize(dtype_) == 0,
                errors::InvalidArgument(
                    "Number of bytes ", num_bytes,
                    " must be divisible by size of datatype ", dtype_));
  }

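  // Allocates the backing tensor as output 0 and registers a ScopedAllocator
  // for it with the device's ScopedAllocatorMgr, keyed by step id and "id".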
  void Compute(OpKernelContext* context) override {
    ScopedAllocatorMgr* sam = device_->GetScopedAllocatorMgr();
    if (!sam) {
      context->SetStatus(errors::Internal(
          "ScopedAllocatorMgr not supported on device ", device_->name()));
      return;
    }
    Tensor* backing_tensor = nullptr;
    AllocatorAttributes attr = context->output_alloc_attr(0);
    Status s =
        context->allocate_output(0, {num_elements_}, &backing_tensor, attr);
    if (s.ok()) {
      VLOG(1) << "_ScopedAllocatorOp " << context->op_kernel().name()
              << " new backing tensor size " << backing_tensor->TotalBytes()
              << " num_elements_ " << num_elements_ << " buffer "
              << DMAHelper::buffer(backing_tensor) << " base addr "
              << DMAHelper::base(backing_tensor);
      s = sam->AddScopedAllocator(*backing_tensor, context->step_id(), id_,
                                  name_, fields_, expected_call_count_);
    }
    if (!s.ok()) {
      context->SetStatus(s);
    }
  }

 private:
  std::vector<TensorShape> shapes_;
  DataType dtype_;
  int64_t num_elements_;
  std::vector<ScopedAllocator::Field> fields_;
  string name_;
  int32 id_;
  int32 expected_call_count_;
  DeviceBase* device_;
};

REGISTER_KERNEL_BUILDER(Name("_ScopedAllocator").Device(DEVICE_CPU),
                        ScopedAllocatorOp);

REGISTER_KERNEL_BUILDER(Name("_ScopedAllocator").Device(DEVICE_GPU),
                        ScopedAllocatorOp);

REGISTER_KERNEL_BUILDER(Name("_ScopedAllocator").Device(DEVICE_DEFAULT),
                        ScopedAllocatorOp);

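// Outputs a view of the backing tensor produced by a matching
// _ScopedAllocator op, optionally reshaped to "shape".  Its remaining inputs
// are the per-field tensors that were allocated out of that backing buffer;
// they are validated to lie entirely within it, which makes the concat a
// zero-copy operation.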
class ScopedAllocatorConcatOp : public OpKernel {
 public:
  explicit ScopedAllocatorConcatOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
    OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
    OP_REQUIRES_OK(context, context->GetAttr("reshape", &reshape_));
    // These attributes are just for debugging.
    OP_REQUIRES_OK(context, context->GetAttr("sa_name", &name_));
    OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
    device_ = context->device();
  }

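  // Emits the backing tensor's buffer as output 0 (reshaped if requested)
  // after verifying that every other input aliases a region inside that
  // buffer; no data is copied.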
  void Compute(OpKernelContext* context) override {
    const Tensor& backing_tensor = context->input(0);
    // Check that type matches.
    OP_REQUIRES(context, backing_tensor.dtype() == dtype_,
                errors::InvalidArgument("Backing tensor type ",
                                        DataTypeString(backing_tensor.dtype()),
                                        " does not match expected type ",
                                        DataTypeString(dtype_)));
    // Check that the backing tensor is at least as large as the shape of the
    // output.
    OP_REQUIRES(context, backing_tensor.NumElements() >= shape_.num_elements(),
                errors::InvalidArgument("Backing tensor num elements ",
                                        backing_tensor.NumElements(),
                                        " is not >= expected ",
                                        shape_.num_elements()));
    Tensor output(dtype_);
    if (reshape_) {
      CHECK(output.CopyFrom(backing_tensor, shape_));
    } else {
      CHECK(output.CopyFrom(backing_tensor, backing_tensor.shape()));
    }
    context->set_output(0, output);
    const TensorBuffer* backing_buf = DMAHelper::buffer(&output);
    const void* backing_tensor_lb = backing_buf->data();
    const void* backing_tensor_ub = static_cast<const void*>(
        static_cast<const char*>(backing_tensor_lb) + backing_buf->size());
    // Check that all inputs lie entirely within the backing tensor.
    for (int i = 1; i < context->num_inputs(); ++i) {
      const TensorBuffer* input_buf = DMAHelper::buffer(&context->input(i));
      const void* input_lb = input_buf->data();
      const void* input_ub = static_cast<const void*>(
          static_cast<const char*>(input_lb) + input_buf->size());
      OP_REQUIRES(
          context, input_lb >= backing_tensor_lb,
          errors::InvalidArgument(
              "Lower bound check fail for input ", i, " from node ",
              context->op_kernel().requested_input(i), " to node ",
              context->op_kernel().name(), " input bounds = [", input_lb, ", ",
              input_ub, "]", " backing_tensor bounds = [", backing_tensor_lb,
              ", ", backing_tensor_ub, "]"));
      OP_REQUIRES(
          context, input_ub <= backing_tensor_ub,
          errors::InvalidArgument(
              "Upper bound check fail for input ", i, " from node ",
              context->op_kernel().requested_input(i), " to node ",
              context->op_kernel().name(), " input bounds = [", input_lb, ", ",
              input_ub, "]", " backing_tensor bounds = [", backing_tensor_lb,
              ", ", backing_tensor_ub, "]"));
    }
    VLOG(1) << "_ScopedAllocatorConcatOp outputting backing tensor at "
            << backing_buf;
  }

 private:
  TensorShape shape_;
  DataType dtype_;
  string name_;
  int32 id_;
  bool reshape_;
  DeviceBase* device_;
};

REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorConcat").Device(DEVICE_CPU),
                        ScopedAllocatorConcatOp);

REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorConcat").Device(DEVICE_GPU),
                        ScopedAllocatorConcatOp);

REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorConcat").Device(DEVICE_DEFAULT),
                        ScopedAllocatorConcatOp);

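// The inverse of _ScopedAllocatorConcat: takes the backing tensor as input 0
// plus one additional input per field, and re-emits each field tensor as a
// separate output after verifying that it lies entirely within the backing
// buffer.  Like the concat, this is a zero-copy aliasing operation.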
class ScopedAllocatorSplitOp : public OpKernel {
 public:
  explicit ScopedAllocatorSplitOp(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("T", &dtype_));
    // These attributes are just for debugging.
    OP_REQUIRES_OK(context, context->GetAttr("sa_name", &name_));
    OP_REQUIRES_OK(context, context->GetAttr("id", &id_));
    device_ = context->device();
  }

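  // Forwards input i (i >= 1) to output i - 1, checking that each one has the
  // expected dtype and aliases a region inside the backing tensor (input 0).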
  void Compute(OpKernelContext* context) override {
    Tensor backing_copy(context->input(0));
    // Check that type matches.
    OP_REQUIRES(context, backing_copy.dtype() == dtype_,
                errors::InvalidArgument("Backing tensor type ",
                                        DataTypeString(backing_copy.dtype()),
                                        " does not match expected type ",
                                        DataTypeString(dtype_)));
    const TensorBuffer* backing_buf = DMAHelper::buffer(&backing_copy);
    const void* backing_tensor_lb = backing_buf->data();
    const void* backing_tensor_ub = static_cast<const void*>(
        static_cast<const char*>(backing_tensor_lb) + backing_buf->size());
    for (int i = 1; i < context->num_inputs(); ++i) {
      VLOG(1) << "_ScopedAllocatorSplitOp assigning input " << i
              << " to output " << i - 1 << " buf addr "
              << DMAHelper::base(&context->input(i));
      Tensor copy(context->input(i));
      OP_REQUIRES(context, copy.dtype() == dtype_,
                  errors::InvalidArgument("Input ", i, " tensor type ",
                                          DataTypeString(copy.dtype()),
                                          " does not match expected type ",
                                          DataTypeString(dtype_)));
      context->set_output(i - 1, copy);
      const TensorBuffer* input_buf = DMAHelper::buffer(&copy);
      const void* input_lb = input_buf->data();
      OP_REQUIRES(
          context, input_lb >= backing_tensor_lb,
          errors::InvalidArgument("Lower bound check fail for input ", i,
                                  " to node ", context->op_kernel().name()));
      const void* input_ub = static_cast<const void*>(
          static_cast<const char*>(input_lb) + input_buf->size());
      OP_REQUIRES(
          context, input_ub <= backing_tensor_ub,
          errors::InvalidArgument("Upper bound check fail for input ", i,
                                  " to node ", context->op_kernel().name()));
    }
  }

 private:
  DataType dtype_;
  string name_;
  int32 id_;
  DeviceBase* device_;
};

REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorSplit").Device(DEVICE_CPU),
                        ScopedAllocatorSplitOp);

REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorSplit").Device(DEVICE_GPU),
                        ScopedAllocatorSplitOp);

REGISTER_KERNEL_BUILDER(Name("_ScopedAllocatorSplit").Device(DEVICE_DEFAULT),
                        ScopedAllocatorSplitOp);

}  // namespace tensorflow