1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #define EIGEN_USE_THREADS |
17 | |
18 | #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ |
19 | (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) |
20 | #define EIGEN_USE_GPU |
21 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
22 | |
23 | #include "tensorflow/core/kernels/broadcast_to_op.h" |
24 | |
25 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
26 | #include "tensorflow/core/framework/op_kernel.h" |
27 | #include "tensorflow/core/framework/register_types.h" |
28 | #include "tensorflow/core/framework/tensor.h" |
29 | #include "tensorflow/core/framework/tensor_util.h" |
30 | #include "tensorflow/core/framework/types.h" |
31 | #include "tensorflow/core/util/bcast.h" |
32 | |
33 | namespace tensorflow { |
34 | |
35 | typedef Eigen::ThreadPoolDevice CPUDevice; |
36 | typedef Eigen::GpuDevice GPUDevice; |
37 | |
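// BroadcastToOp materializes the NumPy-style broadcast of `input` to the
// requested `shape`. As a sketch of the op's semantics:
//   input shape [3]    broadcast to [2, 3]: the row is repeated twice.
//   input shape [2, 1] broadcast to [2, 3]: the column is repeated 3 times.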
38 | template <typename Device, typename T> |
39 | class BroadcastToOp : public OpKernel { |
40 | public: |
41 | explicit BroadcastToOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
42 | |
43 | void Compute(OpKernelContext* ctx) override { |
44 | const Tensor& input_tensor = ctx->input(0); |
45 | const TensorShape& input_shape = input_tensor.shape(); |
46 | |
47 | const Tensor& shape_tensor = ctx->input(1); |
48 | |
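    // Resolve the requested output shape: the `shape` input is a 1-D tensor
    // of int32 or int64 values that tensor::MakeShape converts into a
    // TensorShape.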
49 | TensorShape output_shape; |
50 | OP_REQUIRES_OK(ctx, tensor::MakeShape(shape_tensor, &output_shape)); |
51 | |
    // If the requested shape equals the input shape, forward the input
    // directly; set_output shares the underlying buffer, so no copy is made.
53 | if (output_shape == input_shape) { |
54 | ctx->set_output(0, input_tensor); |
55 | return; |
56 | } |
57 | |
58 | OP_REQUIRES(ctx, input_shape.dims() <= output_shape.dims(), |
59 | errors::InvalidArgument( |
60 | "Rank of input (" , input_shape.dims(), |
61 | ") must be no greater than rank of output shape (" , |
62 | output_shape.dims(), ")." )); |
63 | |
64 | Tensor* output_tensor = nullptr; |
65 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &output_tensor)); |
66 | |
    // Handle broadcast from a scalar.
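    // Broadcasting a scalar is simply a fill of the output; e.g. broadcasting
    // scalar 5 to shape [2, 2] yields [[5, 5], [5, 5]].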
68 | const Device& device = ctx->eigen_device<Device>(); |
69 | if (input_shape.dims() == 0) { |
70 | functor::FillFunctor<Device, T>()(device, output_tensor->flat<T>(), |
71 | input_tensor.scalar<T>()); |
72 | return; |
73 | } |
74 | |
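    // BCast computes the reshape-and-broadcast plan that maps input_shape
    // onto output_shape. With fewer_dims_optimization, adjacent dimensions
    // that broadcast together are coalesced to reduce the rank of the
    // expansion; e.g. broadcasting [2, 3] to [4, 2, 3] can be treated as
    // broadcasting [6] to [4, 6]. See tensorflow/core/util/bcast.h for the
    // authoritative semantics.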
75 | // Check whether the broadcast is valid. |
76 | BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(output_shape), |
77 | /*fewer_dims_optimization=*/true); |
78 | OP_REQUIRES(ctx, bcast.IsValid(), |
79 | errors::InvalidArgument( |
80 | "Incompatible shapes: " , input_shape.DebugString(), " vs. " , |
81 | output_shape.DebugString())); |
82 | OP_REQUIRES(ctx, BCast::ToShape(bcast.output_shape()) == output_shape, |
83 | errors::InvalidArgument("Unable to broadcast tensor of shape " , |
84 | input_shape, " to tensor of shape " , |
85 | output_shape)); |
86 | |
87 | // Handle empty case. |
88 | if (output_shape.num_elements() == 0) { |
89 | return; |
90 | } |
91 | |
92 | functor::BroadcastTo<Device, T>()(device, ctx, *output_tensor, output_shape, |
93 | input_tensor, input_shape, bcast); |
94 | } |
95 | }; |
96 | |
// Since tensor::MakeShape handles both DT_INT32 and DT_INT64, there is no
// need for a TypeConstraint on `Tidx`.
99 | #define REGISTER_KERNEL(type) \ |
100 | REGISTER_KERNEL_BUILDER( \ |
101 | Name("BroadcastTo").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ |
102 | BroadcastToOp<CPUDevice, type>); |
103 | |
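// As an illustration, REGISTER_KERNEL(float) expands to:
//   REGISTER_KERNEL_BUILDER(
//       Name("BroadcastTo").Device(DEVICE_CPU).TypeConstraint<float>("T"),
//       BroadcastToOp<CPUDevice, float>);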
104 | TF_CALL_ALL_TYPES(REGISTER_KERNEL); |
105 | #undef REGISTER_KERNEL |
106 | |
107 | #if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \ |
108 | (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM) |
109 | |
110 | namespace functor { |
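// Forward-declare the GPU specializations of functor::BroadcastTo so this
// translation unit can reference them; their definitions are instantiated in
// the separately compiled GPU translation unit.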
111 | #define DECLARE_GPU_TEMPLATE(Type) \ |
112 | template <> \ |
113 | void BroadcastTo<GPUDevice, Type>::operator()( \ |
114 | const GPUDevice& d, OpKernelContext* ctx, Tensor& output, \ |
115 | const TensorShape& output_shape, const Tensor& input, \ |
116 | const TensorShape& input_shape, const BCast& bcast) const; \ |
117 | extern template struct BroadcastTo<GPUDevice, Type>; |
118 | |
119 | TF_CALL_GPU_ALL_TYPES(DECLARE_GPU_TEMPLATE); |
120 | TF_CALL_int64(DECLARE_GPU_TEMPLATE); |
#undef DECLARE_GPU_TEMPLATE
122 | } // namespace functor |
123 | |
124 | #define REGISTER_KERNEL(type) \ |
125 | REGISTER_KERNEL_BUILDER(Name("BroadcastTo") \ |
126 | .Device(DEVICE_GPU) \ |
127 | .TypeConstraint<type>("T") \ |
128 | .HostMemory("shape"), \ |
129 | BroadcastToOp<GPUDevice, type>); |
130 | |
131 | TF_CALL_GPU_ALL_TYPES(REGISTER_KERNEL); |
132 | TF_CALL_int64(REGISTER_KERNEL); |
133 | #undef REGISTER_KERNEL |
134 | |
135 | // A special GPU kernel for int32. |
136 | // TODO(b/25387198): Also enable int32 in device memory. This kernel |
137 | // registration requires all int32 inputs and outputs to be in host memory. |
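// Even though it is registered on DEVICE_GPU, the kernel below runs the CPU
// implementation (note BroadcastToOp<CPUDevice, int32>), since all of its
// tensors are pinned to host memory.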
REGISTER_KERNEL_BUILDER(Name("BroadcastTo")
                            .Device(DEVICE_GPU)
                            .TypeConstraint<int32>("T")
                            .HostMemory("input")
                            .HostMemory("shape")
                            .HostMemory("output"),
                        BroadcastToOp<CPUDevice, int32>);
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
146 | |
147 | } // namespace tensorflow |
148 | |