/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "tensorflow/core/kernels/broadcast_to_op.h"

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/bcast.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

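// BroadcastToOp implements the "BroadcastTo" op: the input is replicated
// along dimensions of size one (and along new leading dimensions) until it
// matches the requested output shape, following the usual broadcasting rules.
// For example, broadcasting an input of shape [3, 1] to shape [3, 4] repeats
// each element four times along the last dimension.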
template <typename Device, typename T>
class BroadcastToOp : public OpKernel {
 public:
  explicit BroadcastToOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& input_tensor = ctx->input(0);
    const TensorShape& input_shape = input_tensor.shape();

    const Tensor& shape_tensor = ctx->input(1);

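    // The "shape" input holds the target shape as a vector of int32 or int64
    // values; tensor::MakeShape accepts either type, so no switch on the
    // index type is needed here.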
    TensorShape output_shape;
    OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &output_shape));

    // Handle copy.
    if (output_shape == input_shape) {
      ctx->set_output(0, input_tensor);
      return;
    }

    OP_REQUIRES(ctx, input_shape.dims() <= output_shape.dims(),
                errors::InvalidArgument(
                    "Rank of input (", input_shape.dims(),
                    ") must be no greater than rank of output shape (",
                    output_shape.dims(), ")."));

    Tensor* output_tensor = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor));

    // Handle broadcast from Scalar.
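    // A rank-0 input is expanded by filling every element of the output with
    // its single value.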
    const Device& device = ctx->eigen_device<Device>();
    if (input_shape.dims() == 0) {
      functor::FillFunctor<Device, T>()(device, output_tensor->flat<T>(),
                                        input_tensor.scalar<T>());
      return;
    }

    // Check whether the broadcast is valid.
    BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(output_shape),
                /*fewer_dims_optimization=*/true);
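    // For example, broadcasting shape [3, 1] to [3, 4] is valid, while
    // broadcasting [3, 2] to [3, 4] is not.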
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "Incompatible shapes: ", input_shape.DebugString(), " vs. ",
                    output_shape.DebugString()));
    OP_REQUIRES(ctx, BCast::ToShape(bcast.output_shape()) == output_shape,
                errors::InvalidArgument("Unable to broadcast tensor of shape ",
                                        input_shape, " to tensor of shape ",
                                        output_shape));

    // Handle empty case.
    if (output_shape.num_elements() == 0) {
      return;
    }

    functor::BroadcastTo<Device, T>()(device, ctx, *output_tensor, output_shape,
                                      input_tensor, input_shape, bcast);
  }
};

// As tensor::MakeShape is able to handle both DT_INT32 and DT_INT64, there is
// no need for a TypeConstraint on `Tidx`.
#define REGISTER_KERNEL(type)                                           \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("BroadcastTo").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      BroadcastToOp<CPUDevice, type>);

TF_CALL_ALL_TYPES(REGISTER_KERNEL);
#undef REGISTER_KERNEL

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)

namespace functor {
#define DECLARE_GPU_TEMPLATE(Type)                               \
  template <>                                                    \
  void BroadcastTo<GPUDevice, Type>::operator()(                 \
      const GPUDevice& d, OpKernelContext* ctx, Tensor& output,  \
      const TensorShape& output_shape, const Tensor& input,      \
      const TensorShape& input_shape, const BCast& bcast) const; \
  extern template struct BroadcastTo<GPUDevice, Type>;

TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_TEMPLATE);
TF_CALL_int64(DECLARE_GPU_TEMPLATE);
#undef DECLARE_GPU_TEMPLATE
}  // namespace functor

#define REGISTER_KERNEL(type)                            \
  REGISTER_KERNEL_BUILDER(Name("BroadcastTo")            \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("shape"),      \
                          BroadcastToOp<GPUDevice, type>);

TF_CALL_GPU_ALL_TYPES(REGISTER_KERNEL);
TF_CALL_int64(REGISTER_KERNEL);
#undef REGISTER_KERNEL

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("BroadcastTo")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("shape")
                            .HostMemory("output"),
                        BroadcastToOp<CPUDevice, int32>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow