/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/bfloat16.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {

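// Copies the single row held in `value` (viewed as shape [1, k]) into row
// `loc` of `output`. Negative or out-of-range `loc` values wrap around
// modulo the number of rows.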
template <typename Device, typename T>
Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32_t loc,
                              Tensor* output) {
  auto Tvalue = value.shaped<T, 2>({1, value.NumElements()});
  auto Toutput = output->flat_outer_dims<T>();
  auto nrows = Toutput.dimension(0);
  auto r = (loc % nrows + nrows) % nrows;  // Guard index range.
  Toutput.template chip<0>(r).device(d) = Tvalue.template chip<0>(0);
  return OkStatus();
}

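// CPU specialization: dispatches on the runtime dtype to the typed
// implementation above. POD, tstring, and variant element types are
// supported; anything else is rejected with InvalidArgument.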
template <>
Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32_t loc,
                        Tensor* output) {
  CHECK_EQ(value.dtype(), output->dtype());
  switch (value.dtype()) {
#define CASE(type)                  \
  case DataTypeToEnum<type>::value: \
    return DoParallelConcatUpdate<CPUDevice, type>(d, value, loc, output);
    TF_CALL_POD_TYPES(CASE);
    TF_CALL_tstring(CASE);
    TF_CALL_variant(CASE);
#undef CASE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(value.dtype()));
  }
}

}  // end namespace functor

namespace {

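// Kernel behind the internal _ParallelConcatUpdate op: validates shapes,
// writes the single-row `update` tensor into row `loc_` of `value` in place,
// and forwards the aliased tensor as the output.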
template <typename Device>
class ParallelConcatUpdate : public OpKernel {
 public:
  explicit ParallelConcatUpdate(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("loc", &loc_));
  }

  void Compute(OpKernelContext* ctx) override {
    auto value = ctx->input(0);
    // Value must be at least rank 1, and its 0th dimension must be greater
    // than loc_.
    OP_REQUIRES(ctx, value.dims() >= 1,
                errors::InvalidArgument("value should be at least rank 1."));
    OP_REQUIRES(
        ctx, value.dim_size(0) > loc_,
        errors::InvalidArgument("0th dimension of value = ", value.dim_size(0),
                                " must be greater than loc_ = ", loc_));

    auto update = ctx->input(1);

    OP_REQUIRES(
        ctx, value.dims() == update.dims(),
        errors::InvalidArgument("value and update shapes don't match: ",
                                value.shape().DebugString(), " vs. ",
                                update.shape().DebugString()));
    for (int i = 1; i < value.dims(); ++i) {
      OP_REQUIRES(
          ctx, value.dim_size(i) == update.dim_size(i),
          errors::InvalidArgument("value and update shapes don't match: ",
                                  value.shape().DebugString(), " vs. ",
                                  update.shape().DebugString()));
    }
    OP_REQUIRES(ctx, update.dim_size(0) == 1,
                errors::InvalidArgument("update's 0th dimension must be 1: ",
                                        update.shape().DebugString()));

    Tensor output = value;  // This creates an alias intentionally.
    const auto& d = ctx->eigen_device<Device>();
    OP_REQUIRES_OK(
        ctx, ::tensorflow::functor::DoParallelConcat(d, update, loc_, &output));
    ctx->set_output(0, output);
  }

 private:
  int32 loc_;
};

template <typename Device, typename T>
class ParallelConcatStart : public OpKernel {
 public:
  explicit ParallelConcatStart(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
  }

  void Compute(OpKernelContext* ctx) override {
    Tensor* out = nullptr;
    // We do not know whether the output will be used on the GPU, so set it
    // to be GPU-compatible for now.
    AllocatorAttributes attr;
    attr.set_gpu_compatible(true);
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, shape_, &out, attr));
  }

 private:
  TensorShape shape_;
};

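// Registered for the public ParallelConcat op, which the graph optimizer is
// expected to rewrite into _ParallelConcatStart/_ParallelConcatUpdate pairs.
// Reaching this kernel means the rewrite did not happen, so construction
// fails with an Internal error.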
class FailureKernel : public OpKernel {
 public:
  explicit FailureKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   errors::Internal("Found instance of parallel_stack which "
                                    "could not be properly replaced."));
  }

  void Compute(OpKernelContext*) override {}
};

#define REGISTER(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")   \
                              .Device(DEVICE_CPU)         \
                              .TypeConstraint<type>("T"), \
                          ParallelConcatUpdate<CPUDevice>);
TF_CALL_POD_STRING_TYPES(REGISTER)
#undef REGISTER

#define REGISTER_EMPTY(type)                                  \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart")        \
                              .Device(DEVICE_CPU)             \
                              .TypeConstraint<type>("dtype"), \
                          ParallelConcatStart<CPUDevice, type>)

TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY)
#undef REGISTER_EMPTY

#define REGISTER_PARALLEL_CONCAT(type)                                     \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("ParallelConcat").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      FailureKernel);
TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT);
#undef REGISTER_PARALLEL_CONCAT

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

#define REGISTER_PARALLEL_CONCAT_START(type)                  \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart")        \
                              .Device(DEVICE_GPU)             \
                              .TypeConstraint<type>("dtype"), \
                          ParallelConcatStart<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT_START)
#undef REGISTER_PARALLEL_CONCAT_START

#define REGISTER_PARALLEL_CONCAT(type)                                     \
  REGISTER_KERNEL_BUILDER(                                                 \
      Name("ParallelConcat").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      FailureKernel);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_CONCAT);
#undef REGISTER_PARALLEL_CONCAT

#define REGISTER(type)                                    \
  REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")   \
                              .Device(DEVICE_GPU)         \
                              .TypeConstraint<type>("T"), \
                          ParallelConcatUpdate<GPUDevice>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER)
#undef REGISTER

// Register versions that operate on int32 data on the CPU even though the op
// has been placed on the GPU.

REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
                            .Device(DEVICE_GPU)
                            .HostMemory("value")
                            .HostMemory("update")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ParallelConcatUpdate<CPUDevice>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

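// Common shape validation for InplaceUpdate/InplaceAdd/InplaceSub: `i` must
// be a vector of row indices, `x` and `v` must agree on every dimension but
// the first, and `i` must hold one index per row of `v`. The output aliases
// `x`; the actual row-wise work is delegated to DoCompute.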
class InplaceOpBase : public OpKernel {
 public:
  explicit InplaceOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    auto x = ctx->input(0);
    auto i = ctx->input(1);
    auto v = ctx->input(2);

    OP_REQUIRES(ctx, TensorShapeUtils::IsVector(i.shape()),
                errors::InvalidArgument("i must be a vector. ",
                                        i.shape().DebugString()));
    OP_REQUIRES(ctx, x.dims() == v.dims(),
                errors::InvalidArgument(
                    "x and v shapes don't match (ranks differ): ",
                    x.shape().DebugString(), " vs. ", v.shape().DebugString()));
    for (int dim = 1; dim < x.dims(); ++dim) {
      OP_REQUIRES(
          ctx, x.dim_size(dim) == v.dim_size(dim),
          errors::InvalidArgument("x and v shapes don't match at index ", dim,
                                  " : ", x.shape().DebugString(), " vs. ",
                                  v.shape().DebugString()));
    }
    OP_REQUIRES(ctx, i.dim_size(0) == v.dim_size(0),
                errors::InvalidArgument(
                    "i and v shapes don't match at index 0: ",
                    i.shape().DebugString(), " vs. ", v.shape().DebugString()));

    Tensor y = x;  // This creates an alias intentionally.
    // Skip processing if tensors are empty.
    if (x.NumElements() > 0 && v.NumElements() > 0) {
      OP_REQUIRES_OK(ctx, DoCompute(ctx, i, v, &y));
    }
    ctx->set_output(0, y);
  }

 protected:
  virtual Status DoCompute(OpKernelContext* ctx, const Tensor& i,
                           const Tensor& v, Tensor* y) = 0;
};

}  // end namespace

namespace functor {

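// Row-wise inplace kernel on CPU: for each index Ti(j), applies `op` to row
// r of `y` using row j of `v`, where r wraps modulo the number of rows. For
// example, with 4 rows an index of -1 addresses row 3.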
template <typename T>
void DoInplaceOp(const CPUDevice& d, InplaceOpType op, const Tensor& i,
                 const Tensor& v, Tensor* y) {
  auto Ti = i.flat<int32>();
  auto Tv = v.flat_outer_dims<T>();
  auto Ty = y->flat_outer_dims<T>();
  auto nrows = Ty.dimension(0);
  for (int64_t j = 0; j < Ti.size(); ++j) {
    auto r = (Ti(j) % nrows + nrows) % nrows;  // Guard index range.
    switch (op) {
      case I_UPDATE:
        Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j);
        break;
      case I_ADD:
        Ty.template chip<0>(r).device(d) += Tv.template chip<0>(j);
        break;
      case I_SUB:
        Ty.template chip<0>(r).device(d) -= Tv.template chip<0>(j);
        break;
    }
  }
}

// String type only supports inplace update.
void DoInplaceStringUpdateOp(const CPUDevice& d, const Tensor& i,
                             const Tensor& v, Tensor* y) {
  auto Ti = i.flat<int32>();
  auto Tv = v.flat_outer_dims<tstring>();
  auto Ty = y->flat_outer_dims<tstring>();
  auto nrows = Ty.dimension(0);
  for (int64_t j = 0; j < Ti.size(); ++j) {
    auto r = (Ti(j) % nrows + nrows) % nrows;  // Guard index range.
    Ty.template chip<0>(r).device(d) = Tv.template chip<0>(j);
  }
}

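// CPU specialization of DoInplace: strings and bools are only supported for
// I_UPDATE; numeric types support update, add, and subtract. Unsupported
// dtypes are rejected with InvalidArgument.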
template <>
Status DoInplace(const CPUDevice& device, InplaceOpType op, const Tensor& i,
                 const Tensor& v, Tensor* y) {
  CHECK_EQ(v.dtype(), y->dtype());
  if (op == I_UPDATE) {
    if (v.dtype() == DT_STRING) {
      DoInplaceStringUpdateOp(device, i, v, y);
      return OkStatus();
    } else if (v.dtype() == DT_BOOL) {
      DoInplaceOp<bool>(device, op, i, v, y);
      return OkStatus();
    }
  }
  switch (v.dtype()) {
#define CASE(type)                          \
  case DataTypeToEnum<type>::value:         \
    DoInplaceOp<type>(device, op, i, v, y); \
    break;
    TF_CALL_NUMBER_TYPES(CASE);
#undef CASE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(v.dtype()));
  }
  return OkStatus();
}

}  // end namespace functor

namespace {
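// Thin wrapper that binds a device and an InplaceOpType to InplaceOpBase and
// forwards the validated tensors to functor::DoInplace.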
template <typename Device, functor::InplaceOpType op>
class InplaceOp : public InplaceOpBase {
 public:
  explicit InplaceOp(OpKernelConstruction* ctx) : InplaceOpBase(ctx) {}

 protected:
  Status DoCompute(OpKernelContext* ctx, const Tensor& i, const Tensor& v,
                   Tensor* y) override {
    const auto& d = ctx->eigen_device<Device>();
    return ::tensorflow::functor::DoInplace(d, op, i, v, y);
  }
};

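// Base kernel for DeepCopy: allocates a fresh output of x's shape and lets
// the device-specific DoCompute fill it, so the result never aliases x.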
class CopyOpBase : public OpKernel {
 public:
  explicit CopyOpBase(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    auto x = ctx->input(0);
    Tensor* y;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, x.shape(), &y));
    OP_REQUIRES_OK(ctx, DoCompute(ctx, x, y));
  }

 protected:
  virtual Status DoCompute(OpKernelContext* ctx, const Tensor& x,
                           Tensor* y) = 0;
};

template <typename Device>
class CopyOp : public CopyOpBase {
 public:
  explicit CopyOp(OpKernelConstruction* ctx) : CopyOpBase(ctx) {}

 protected:
  Status DoCompute(OpKernelContext* ctx, const Tensor& x, Tensor* y) override {
    const auto& d = ctx->eigen_device<Device>();
    return ::tensorflow::functor::DoCopy(d, x, y);
  }
};

}  // end namespace

namespace functor {

typedef Eigen::ThreadPoolDevice CPUDevice;

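// CPU specialization of DoCopy: element-wise copy dispatched on dtype.
// Numeric, bool, and tstring element types are supported.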
template <>
Status DoCopy(const CPUDevice& device, const Tensor& x, Tensor* y) {
  CHECK_EQ(x.dtype(), y->dtype());
  switch (x.dtype()) {
#define CASE(type)                                   \
  case DataTypeToEnum<type>::value:                  \
    y->flat<type>().device(device) = x.flat<type>(); \
    break;

    TF_CALL_NUMBER_TYPES(CASE);
    TF_CALL_bool(CASE);
    TF_CALL_tstring(CASE);
#undef CASE
    default:
      return errors::InvalidArgument("Unsupported data type: ",
                                     DataTypeString(x.dtype()));
  }
  return OkStatus();
}

}  // end namespace functor

namespace {
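// Kernel behind the Empty op: builds a TensorShape from the int32 `shape`
// input, allocates the output, and zero-fills it only when the `init`
// attribute is set; otherwise the memory is left uninitialized.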
template <typename Device, typename T>
class EmptyOp : public OpKernel {
 public:
  explicit EmptyOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("init", &init_));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape = ctx->input(0);
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape.shape()),
        errors::InvalidArgument("shape must be a vector of int32, got shape ",
                                shape.shape().DebugString()));
    auto dims = shape.flat<int32>();
    TensorShape out_shape;
    OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(dims.data(), dims.size(),
                                                    &out_shape));
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (init_) {
      functor::SetZeroFunctor<Device, T>()(ctx->eigen_device<Device>(),
                                           out->flat<T>());
    }
  }

 private:
  bool init_;
};

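// CPU kernel registrations. In Python these kernels back the corresponding
// raw ops, e.g. (illustrative):
//   x = tf.raw_ops.Empty(shape=[8, 4], dtype=tf.float32, init=True)
//   x = tf.raw_ops.InplaceUpdate(x=x, i=[0, 2], v=tf.ones([2, 4]))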
REGISTER_KERNEL_BUILDER(Name("InplaceUpdate").Device(DEVICE_CPU),
                        InplaceOp<CPUDevice, functor::I_UPDATE>);
REGISTER_KERNEL_BUILDER(Name("InplaceAdd").Device(DEVICE_CPU),
                        InplaceOp<CPUDevice, functor::I_ADD>);
REGISTER_KERNEL_BUILDER(Name("InplaceSub").Device(DEVICE_CPU),
                        InplaceOp<CPUDevice, functor::I_SUB>);
REGISTER_KERNEL_BUILDER(Name("DeepCopy").Device(DEVICE_CPU), CopyOp<CPUDevice>);

#define REGISTER_EMPTY(type, dev)                             \
  REGISTER_KERNEL_BUILDER(Name("Empty")                       \
                              .Device(DEVICE_##dev)           \
                              .HostMemory("shape")            \
                              .TypeConstraint<type>("dtype"), \
                          EmptyOp<dev##Device, type>)

REGISTER_EMPTY(float, CPU)
REGISTER_EMPTY(bfloat16, CPU)
REGISTER_EMPTY(double, CPU)
REGISTER_EMPTY(Eigen::half, CPU)
REGISTER_EMPTY(tstring, CPU)
REGISTER_EMPTY(int32, CPU)
REGISTER_EMPTY(int64_t, CPU)
REGISTER_EMPTY(bool, CPU)
REGISTER_EMPTY(uint8, CPU)

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

typedef Eigen::GpuDevice GPUDevice;

#define REGISTER(TYPE)                                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      InplaceOp<GPUDevice, functor::I_UPDATE>);                           \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("InplaceAdd").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),    \
      InplaceOp<GPUDevice, functor::I_ADD>);                              \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("InplaceSub").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),    \
      InplaceOp<GPUDevice, functor::I_SUB>);                              \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("DeepCopy").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),      \
      CopyOp<GPUDevice>);

REGISTER_KERNEL_BUILDER(
    Name("InplaceUpdate").Device(DEVICE_GPU).TypeConstraint<bool>("T"),
    InplaceOp<GPUDevice, functor::I_UPDATE>);
REGISTER(float);
REGISTER(double);
REGISTER(Eigen::half);
REGISTER(int64_t);

REGISTER_EMPTY(float, GPU);
REGISTER_EMPTY(double, GPU);
REGISTER_EMPTY(Eigen::half, GPU);
REGISTER_EMPTY(int64_t, GPU);
REGISTER_EMPTY(int32, GPU);

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

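// int32 variants run on the host (CPUDevice) even when the op is placed on
// another device, so all of their tensors live in host memory.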
REGISTER_KERNEL_BUILDER(Name("InplaceUpdate")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("x")
                            .HostMemory("i")
                            .HostMemory("v")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        InplaceOp<CPUDevice, functor::I_UPDATE>);
REGISTER_KERNEL_BUILDER(Name("InplaceAdd")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("x")
                            .HostMemory("i")
                            .HostMemory("v")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        InplaceOp<CPUDevice, functor::I_ADD>);
REGISTER_KERNEL_BUILDER(Name("InplaceSub")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("x")
                            .HostMemory("i")
                            .HostMemory("v")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        InplaceOp<CPUDevice, functor::I_SUB>);

REGISTER_KERNEL_BUILDER(Name("DeepCopy")
                            .Device(DEVICE_DEFAULT)
                            .HostMemory("x")
                            .HostMemory("y")
                            .TypeConstraint<int32>("T"),
                        CopyOp<CPUDevice>);

}  // end namespace
}  // end namespace tensorflow