1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #include "tensorflow/core/framework/bfloat16.h" |
19 | #include "tensorflow/core/framework/op_kernel.h" |
20 | #include "tensorflow/core/framework/register_types.h" |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/framework/tensor_shape.h" |
23 | #include "tensorflow/core/kernels/fill_functor.h" |
24 | #include "tensorflow/core/kernels/inplace_ops_functor.h" |
25 | #include "tensorflow/core/lib/core/status.h" |
26 | |
27 | namespace tensorflow { |
28 | typedef Eigen::ThreadPoolDevice CPUDevice; |
29 | |
30 | namespace functor { |
31 | |
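// Copies the contents of `value` into row `loc` of `output`, viewing `value`
// as a single row and `output` as 2-D with all but its first dimension
// flattened.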
32 | template <typename Device, typename T> |
33 | Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32_t loc, |
34 | Tensor* output) { |
35 | auto Tvalue = value.shaped<T, 2>({1, value.NumElements()}); |
36 | auto Toutput = output->flat_outer_dims<T>(); |
37 | auto nrows = Toutput.dimension(0); |
38 | auto r = (loc % nrows + nrows) % nrows; // Guard index range. |
39 | Toutput.template chip<0>(r).device(d) = Tvalue.template chip<0>(0); |
40 | return OkStatus(); |
41 | } |
42 | |
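// CPU specialization of DoParallelConcat: dispatches on the runtime dtype to
// the typed implementation above.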
43 | template <> |
44 | Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32_t loc, |
45 | Tensor* output) { |
46 | CHECK_EQ(value.dtype(), output->dtype()); |
47 | switch (value.dtype()) { |
48 | #define CASE(type) \ |
49 | case DataTypeToEnum<type>::value: \ |
50 | return DoParallelConcatUpdate<CPUDevice, type>(d, value, loc, output); |
51 | TF_CALL_POD_TYPES(CASE); |
52 | TF_CALL_tstring(CASE); |
53 | TF_CALL_variant(CASE); |
54 | #undef CASE |
55 | default: |
56 | return errors::InvalidArgument("Unsupported data type: " , |
57 | DataTypeString(value.dtype())); |
58 | } |
59 | } |
60 | |
61 | } // end namespace functor |
62 | |
63 | namespace { |
64 | |
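// Writes the single-row `update` tensor into row `loc_` of `value`, returning
// `value` itself (aliased) as the output.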
65 | template <typename Device> |
66 | class ParallelConcatUpdate : public OpKernel { |
67 | public: |
68 | explicit ParallelConcatUpdate(OpKernelConstruction* ctx) : OpKernel(ctx) { |
69 | OP_REQUIRES_OK(ctx, ctx->GetAttr("loc" , &loc_)); |
70 | } |
71 | |
72 | void Compute(OpKernelContext* ctx) override { |
73 | auto value = ctx->input(0); |
    // Value must be at least rank 1, and its 0th dimension must be strictly
    // greater than loc_ so that row loc_ exists.
    OP_REQUIRES(ctx, value.dims() >= 1,
                errors::InvalidArgument("value should be at least rank 1."));
    OP_REQUIRES(
        ctx, value.dim_size(0) > loc_,
        errors::InvalidArgument("0th dimension of value = ", value.dim_size(0),
                                " must be greater than loc_ = ", loc_));
82 | |
83 | auto update = ctx->input(1); |
84 | |
    OP_REQUIRES(
        ctx, value.dims() == update.dims(),
        errors::InvalidArgument("value and update shapes don't match: ",
                                value.shape().DebugString(), " vs. ",
                                update.shape().DebugString()));
    for (int i = 1; i < value.dims(); ++i) {
      OP_REQUIRES(
          ctx, value.dim_size(i) == update.dim_size(i),
          errors::InvalidArgument("value and update shapes don't match: ",
                                  value.shape().DebugString(), " vs. ",
                                  update.shape().DebugString()));
    }
    OP_REQUIRES(ctx, 1 == update.dim_size(0),
                errors::InvalidArgument("0th dimension of update must be 1, "
                                        "got shape ",
                                        update.shape().DebugString()));
100 | |
101 | Tensor output = value; // This creates an alias intentionally. |
102 | const auto& d = ctx->eigen_device<Device>(); |
103 | OP_REQUIRES_OK( |
104 | ctx, ::tensorflow::functor::DoParallelConcat(d, update, loc_, &output)); |
105 | ctx->set_output(0, output); |
106 | } |
107 | |
108 | private: |
109 | int32 loc_; |
110 | }; |
111 | |
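// Allocates the (uninitialized) output buffer that subsequent
// _ParallelConcatUpdate ops write their rows into.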
112 | template <typename Device, typename T> |
113 | class ParallelConcatStart : public OpKernel { |
114 | public: |
115 | explicit ParallelConcatStart(OpKernelConstruction* ctx) : OpKernel(ctx) { |
116 | OP_REQUIRES_OK(ctx, ctx->GetAttr("shape" , &shape_)); |
117 | } |
118 | |
119 | void Compute(OpKernelContext* ctx) override { |
120 | Tensor* out = nullptr; |
    // We do not know whether the output will be used on the GPU, so we mark
    // it GPU-compatible for now.
123 | AllocatorAttributes attr; |
124 | attr.set_gpu_compatible(true); |
125 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape_, &out, attr)); |
126 | } |
127 | |
128 | private: |
129 | TensorShape shape_; |
130 | }; |
131 | |
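// Stand-in kernel for ParallelConcat that always fails at construction time;
// the graph rewrite is expected to have replaced every instance before any
// kernel is instantiated.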
132 | class FailureKernel : public OpKernel { |
133 | public: |
134 | explicit FailureKernel(OpKernelConstruction* ctx) : OpKernel(ctx) { |
135 | OP_REQUIRES_OK(ctx, |
136 | errors::Internal("Found instance of parallel_stack which " |
137 | "could not be properly replaced." )); |
138 | } |
139 | |
140 | void Compute(OpKernelContext*) override {} |
141 | }; |
142 | |
143 | #define REGISTER(type) \ |
144 | REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \ |
145 | .Device(DEVICE_CPU) \ |
146 | .TypeConstraint<type>("T"), \ |
147 | ParallelConcatUpdate<CPUDevice>); |
148 | TF_CALL_POD_STRING_TYPES(REGISTER) |
149 | #undef REGISTER |
150 | |
151 | #define REGISTER_EMPTY(type) \ |
152 | REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \ |
153 | .Device(DEVICE_CPU) \ |
154 | .TypeConstraint<type>("dtype"), \ |
155 | ParallelConcatStart<CPUDevice, type>) |
156 | |
157 | TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY) |
158 | #undef REGISTER_EMPTY |
159 | |
160 | #define REGISTER_PARALLEL_CONCAT(type) \ |
161 | REGISTER_KERNEL_BUILDER( \ |
162 | Name("ParallelConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
163 | FailureKernel); |
164 | TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT); |
165 | #undef REGISTER_PARALLEL_CONCAT |
166 | |
167 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
168 | |
169 | typedef Eigen::GpuDevice GPUDevice; |
170 | |
171 | #define REGISTER_PARALLEL_CONCAT_START(type) \ |
172 | REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \ |
173 | .Device(DEVICE_GPU) \ |
174 | .TypeConstraint<type>("dtype"), \ |
175 | ParallelConcatStart<GPUDevice, type>); |
176 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT_START) |
177 | #undef REGISTER_PARALLEL_CONCAT_START |
178 | |
179 | #define REGISTER_PARALLEL_CONCAT(type) \ |
180 | REGISTER_KERNEL_BUILDER( \ |
181 | Name("ParallelConcat").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ |
182 | FailureKernel); |
183 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT); |
184 | #undef REGISTER_PARALLEL_CONCAT |
185 | |
186 | #define REGISTER(type) \ |
187 | REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \ |
188 | .Device(DEVICE_GPU) \ |
189 | .TypeConstraint<type>("T"), \ |
190 | ParallelConcatUpdate<GPUDevice>); |
191 | TF_CALL_GPU_NUMBER_TYPES(REGISTER) |
192 | #undef REGISTER |
193 | |
// Register a version that operates on int32 data on the CPU even though the
// op has been placed on the GPU.
196 | |
197 | REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate" ) |
198 | .Device(DEVICE_GPU) |
199 | .HostMemory("value" ) |
200 | .HostMemory("update" ) |
201 | .HostMemory("output" ) |
202 | .TypeConstraint<int32>("T" ), |
203 | ParallelConcatUpdate<CPUDevice>); |
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
205 | |
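// Base class for the in-place ops. Input 0 (`x`) is the tensor to update,
// input 1 (`i`) is a vector of row indices, and input 2 (`v`) supplies one
// replacement row per index. The output aliases `x`.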
206 | class InplaceOpBase : public OpKernel { |
207 | public: |
208 | explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
209 | |
210 | void Compute(OpKernelContext* ctx) override { |
211 | auto x = ctx->input(0); |
212 | auto i = ctx->input(1); |
213 | auto v = ctx->input(2); |
214 | |
    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(i.shape()),
                errors::InvalidArgument("i must be a vector. ",
                                        i.shape().DebugString()));
    OP_REQUIRES(ctx, x.dims() == v.dims(),
                errors::InvalidArgument(
                    "x and v shapes don't match (ranks differ): ",
                    x.shape().DebugString(), " vs. ", v.shape().DebugString()));
    for (int j = 1; j < x.dims(); ++j) {
      OP_REQUIRES(
          ctx, x.dim_size(j) == v.dim_size(j),
          errors::InvalidArgument("x and v shapes don't match at index ", j,
                                  ": ", x.shape().DebugString(), " vs. ",
                                  v.shape().DebugString()));
    }
    OP_REQUIRES(ctx, i.dim_size(0) == v.dim_size(0),
                errors::InvalidArgument(
                    "i and v shapes don't match at index 0: ",
                    i.shape().DebugString(), " vs. ", v.shape().DebugString()));
233 | |
234 | Tensor y = x; // This creates an alias intentionally. |
235 | // Skip processing if tensors are empty. |
236 | if (x.NumElements() > 0 && v.NumElements() > 0) { |
237 | OP_REQUIRES_OK(ctx, DoCompute(ctx, i, v, &y)); |
238 | } |
239 | ctx->set_output(0, y); |
240 | } |
241 | |
242 | protected: |
243 | virtual Status DoCompute(OpKernelContext* ctx, const Tensor& i, |
244 | const Tensor& v, Tensor* y) = 0; |
245 | }; |
246 | |
247 | } // end namespace |
248 | |
249 | namespace functor { |
250 | |
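// Applies `op` row by row on the CPU: for each j, row i(j) of `y` is
// overwritten with, incremented by, or decremented by row j of `v`. For
// example, with op == I_ADD, i = [0, 2], and a two-row `v`:
//   y[0, :] += v[0, :];  y[2, :] += v[1, :];
// Indices are wrapped modulo the number of rows to guard the index range.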
251 | template <typename T> |
252 | void DoInplaceOp(const CPUDevice& d, InplaceOpType op, const Tensor& i, |
253 | const Tensor& v, Tensor* y) { |
254 | auto Ti = i.flat<int32>(); |
255 | auto Tv = v.flat_outer_dims<T>(); |
256 | auto Ty = y->flat_outer_dims<T>(); |
257 | auto nrows = Ty.dimension(0); |
258 | for (int64_t j = 0; j < Ti.size(); ++j) { |
259 | auto r = (Ti(j) % nrows + nrows) % nrows; // Guard index range. |
260 | switch (op) { |
261 | case I_UPDATE: |
262 | Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j); |
263 | break; |
264 | case I_ADD: |
265 | Ty.template chip<0>(r).device(d) += Tv.template chip<0>(j); |
266 | break; |
267 | case I_SUB: |
268 | Ty.template chip<0>(r).device(d) -= Tv.template chip<0>(j); |
269 | break; |
270 | } |
271 | } |
272 | } |
273 | |
// The string type supports only in-place update.
275 | void DoInplaceStringUpdateOp(const CPUDevice& d, const Tensor& i, |
276 | const Tensor& v, Tensor* y) { |
277 | auto Ti = i.flat<int32>(); |
278 | auto Tv = v.flat_outer_dims<tstring>(); |
279 | auto Ty = y->flat_outer_dims<tstring>(); |
280 | auto nrows = Ty.dimension(0); |
281 | for (int64_t j = 0; j < Ti.size(); ++j) { |
282 | auto r = (Ti(j) % nrows + nrows) % nrows; // Guard index range. |
283 | Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j); |
284 | } |
285 | } |
286 | |
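// CPU specialization of DoInplace: handles the string and bool special cases
// for I_UPDATE, then dispatches on dtype for the numeric types.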
287 | template <> |
288 | Status DoInplace(const CPUDevice& device, InplaceOpType op, const Tensor& i, |
289 | const Tensor& v, Tensor* y) { |
290 | CHECK_EQ(v.dtype(), y->dtype()); |
291 | if (op == I_UPDATE) { |
292 | if (v.dtype() == DT_STRING) { |
293 | DoInplaceStringUpdateOp(device, i, v, y); |
294 | return OkStatus(); |
295 | } else if (v.dtype() == DT_BOOL) { |
296 | DoInplaceOp<bool>(device, op, i, v, y); |
297 | return OkStatus(); |
298 | } |
299 | } |
300 | switch (v.dtype()) { |
301 | #define CASE(type) \ |
302 | case DataTypeToEnum<type>::value: \ |
303 | DoInplaceOp<type>(device, op, i, v, y); \ |
304 | break; |
305 | TF_CALL_NUMBER_TYPES(CASE); |
306 | #undef CASE |
307 | default: |
308 | return errors::InvalidArgument("Unsupported data type: " , |
309 | DataTypeString(v.dtype())); |
310 | } |
311 | return OkStatus(); |
312 | } |
313 | |
314 | } // end namespace functor |
315 | |
316 | namespace { |
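// Kernel wrapper that binds one InplaceOpType to InplaceOpBase::DoCompute.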
317 | template <typename Device, functor::InplaceOpType op> |
318 | class InplaceOp : public InplaceOpBase { |
319 | public: |
320 | explicit InplaceOp(OpKernelConstruction* ctx) : InplaceOpBase(ctx) {} |
321 | |
322 | protected: |
323 | Status DoCompute(OpKernelContext* ctx, const Tensor& i, const Tensor& v, |
324 | Tensor* y) override { |
325 | const auto& d = ctx->eigen_device<Device>(); |
326 | return ::tensorflow::functor::DoInplace(d, op, i, v, y); |
327 | } |
328 | }; |
329 | |
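// Base class for DeepCopy: allocates a fresh output tensor of the same shape
// as the input and delegates the element copy to DoCompute.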
330 | class CopyOpBase : public OpKernel { |
331 | public: |
332 | explicit CopyOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
333 | |
334 | void Compute(OpKernelContext* ctx) override { |
335 | auto x = ctx->input(0); |
336 | Tensor* y; |
337 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y)); |
338 | OP_REQUIRES_OK(ctx, DoCompute(ctx, x, y)); |
339 | } |
340 | |
341 | protected: |
342 | virtual Status DoCompute(OpKernelContext* ctx, const Tensor& x, |
343 | Tensor* y) = 0; |
344 | }; |
345 | |
346 | template <typename Device> |
347 | class CopyOp : public CopyOpBase { |
348 | public: |
349 | explicit CopyOp(OpKernelConstruction* ctx) : CopyOpBase(ctx) {} |
350 | |
351 | protected: |
352 | Status DoCompute(OpKernelContext* ctx, const Tensor& x, Tensor* y) override { |
353 | const auto& d = ctx->eigen_device<Device>(); |
354 | return ::tensorflow::functor::DoCopy(d, x, y); |
355 | } |
356 | }; |
357 | |
358 | } // end namespace |
359 | |
360 | namespace functor { |
361 | |
362 | typedef Eigen::ThreadPoolDevice CPUDevice; |
363 | |
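// CPU specialization of DoCopy: dispatches on dtype and copies x into y
// element-wise on the given device.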
364 | template <> |
365 | Status DoCopy(const CPUDevice& device, const Tensor& x, Tensor* y) { |
366 | CHECK_EQ(x.dtype(), y->dtype()); |
367 | switch (x.dtype()) { |
368 | #define CASE(type) \ |
369 | case DataTypeToEnum<type>::value: \ |
370 | y->flat<type>().device(device) = x.flat<type>(); \ |
371 | break; |
372 | |
373 | TF_CALL_NUMBER_TYPES(CASE); |
374 | TF_CALL_bool(CASE); |
375 | TF_CALL_tstring(CASE); |
376 | #undef CASE |
377 | default: |
378 | return errors::InvalidArgument("Unsupported data type: " , |
379 | DataTypeString(x.dtype())); |
380 | } |
381 | return OkStatus(); |
382 | } |
383 | |
384 | } // end namespace functor |
385 | |
386 | namespace { |
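// Kernel for Empty: builds a tensor of the shape given by the int32 `shape`
// input and, when the `init` attribute is set, zero-fills it.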
387 | template <typename Device, typename T> |
388 | class EmptyOp : public OpKernel { |
389 | public: |
390 | explicit EmptyOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
391 | OP_REQUIRES_OK(ctx, ctx->GetAttr("init" , &init_)); |
392 | } |
393 | |
394 | void Compute(OpKernelContext* ctx) override { |
395 | const Tensor& shape = ctx->input(0); |
396 | OP_REQUIRES( |
397 | ctx, TensorShapeUtils::IsVector(shape.shape()), |
398 | errors::InvalidArgument("shape must be a vector of int32, got shape " , |
399 | shape.shape().DebugString())); |
400 | auto dims = shape.flat<int32>(); |
401 | TensorShape out_shape; |
402 | OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape( |
403 | reinterpret_cast<const int32*>(dims.data()), |
404 | dims.size(), &out_shape)); |
405 | Tensor* out = nullptr; |
406 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); |
407 | |
408 | if (init_) { |
409 | functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(), |
410 | out->flat<T>()); |
411 | } |
412 | } |
413 | |
414 | private: |
415 | bool init_; |
416 | }; |
417 | |
418 | REGISTER_KERNEL_BUILDER(Name("InplaceUpdate" ).Device(DEVICE_CPU), |
419 | InplaceOp<CPUDevice, functor::I_UPDATE>); |
420 | REGISTER_KERNEL_BUILDER(Name("InplaceAdd" ).Device(DEVICE_CPU), |
421 | InplaceOp<CPUDevice, functor::I_ADD>); |
422 | REGISTER_KERNEL_BUILDER(Name("InplaceSub" ).Device(DEVICE_CPU), |
423 | InplaceOp<CPUDevice, functor::I_SUB>); |
424 | REGISTER_KERNEL_BUILDER(Name("DeepCopy" ).Device(DEVICE_CPU), CopyOp<CPUDevice>); |
425 | |
426 | #define REGISTER_EMPTY(type, dev) \ |
427 | REGISTER_KERNEL_BUILDER(Name("Empty") \ |
428 | .Device(DEVICE_##dev) \ |
429 | .HostMemory("shape") \ |
430 | .TypeConstraint<type>("dtype"), \ |
431 | EmptyOp<dev##Device, type>) |
432 | |
433 | REGISTER_EMPTY(float, CPU) |
434 | REGISTER_EMPTY(bfloat16, CPU) |
435 | REGISTER_EMPTY(double, CPU) |
436 | REGISTER_EMPTY(Eigen::half, CPU) |
437 | REGISTER_EMPTY(tstring, CPU) |
438 | REGISTER_EMPTY(int32, CPU) |
439 | REGISTER_EMPTY(int64_t, CPU) |
440 | REGISTER_EMPTY(bool, CPU) |
441 | REGISTER_EMPTY(uint8, CPU) |
442 | |
443 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
444 | |
445 | typedef Eigen::GpuDevice GPUDevice; |
446 | |
447 | #define REGISTER(TYPE) \ |
448 | REGISTER_KERNEL_BUILDER( \ |
449 | Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ |
450 | InplaceOp<GPUDevice, functor::I_UPDATE>); \ |
451 | REGISTER_KERNEL_BUILDER( \ |
452 | Name("InplaceAdd").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ |
453 | InplaceOp<GPUDevice, functor::I_ADD>); \ |
454 | REGISTER_KERNEL_BUILDER( \ |
455 | Name("InplaceSub").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ |
456 | InplaceOp<GPUDevice, functor::I_SUB>); \ |
457 | REGISTER_KERNEL_BUILDER( \ |
458 | Name("DeepCopy").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ |
459 | CopyOp<GPUDevice>); |
460 | |
REGISTER_KERNEL_BUILDER(
    Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<bool>("T"),
    InplaceOp<GPUDevice, functor::I_UPDATE>);
464 | REGISTER(float); |
465 | REGISTER(double); |
466 | REGISTER(Eigen::half); |
467 | REGISTER(int64_t); |
468 | |
469 | REGISTER_EMPTY(float, GPU); |
470 | REGISTER_EMPTY(double, GPU); |
471 | REGISTER_EMPTY(Eigen::half, GPU); |
472 | REGISTER_EMPTY(int64_t, GPU); |
473 | REGISTER_EMPTY(int32, GPU); |
474 | |
475 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
476 | |
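// As above, register versions that keep int32 data in host memory even though
// the op may be placed on another device.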
477 | REGISTER_KERNEL_BUILDER(Name("InplaceUpdate" ) |
478 | .Device(DEVICE_DEFAULT) |
479 | .HostMemory("x" ) |
480 | .HostMemory("i" ) |
481 | .HostMemory("v" ) |
482 | .HostMemory("y" ) |
483 | .TypeConstraint<int32>("T" ), |
484 | InplaceOp<CPUDevice, functor::I_UPDATE>); |
485 | REGISTER_KERNEL_BUILDER(Name("InplaceAdd" ) |
486 | .Device(DEVICE_DEFAULT) |
487 | .HostMemory("x" ) |
488 | .HostMemory("i" ) |
489 | .HostMemory("v" ) |
490 | .HostMemory("y" ) |
491 | .TypeConstraint<int32>("T" ), |
492 | InplaceOp<CPUDevice, functor::I_ADD>); |
493 | REGISTER_KERNEL_BUILDER(Name("InplaceSub" ) |
494 | .Device(DEVICE_DEFAULT) |
495 | .HostMemory("x" ) |
496 | .HostMemory("i" ) |
497 | .HostMemory("v" ) |
498 | .HostMemory("y" ) |
499 | .TypeConstraint<int32>("T" ), |
500 | InplaceOp<CPUDevice, functor::I_SUB>); |
501 | |
502 | REGISTER_KERNEL_BUILDER(Name("DeepCopy" ) |
503 | .Device(DEVICE_DEFAULT) |
504 | .HostMemory("x" ) |
505 | .HostMemory("y" ) |
506 | .TypeConstraint<int32>("T" ), |
507 | CopyOp<CPUDevice>); |
508 | |
509 | } // end namespace |
510 | } // end namespace tensorflow |
511 | |