/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/aggregate_ops.h"

#include <numeric>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op_requires.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/kernels/aggregate_ops_cpu.h"
#include "tensorflow/core/kernels/variant_ops_util.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

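// AddNOp computes the element-wise sum of its N inputs, all of which must
// have the same shape. When possible it forwards one of the input buffers to
// the output and accumulates the remaining inputs into it, avoiding an extra
// allocation.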
template <typename Device, typename T>
class AddNOp : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    if (!ctx->ValidateInputsAreSameShape(this)) return;

    const Tensor& input0 = ctx->input(0);
    const int num = ctx->num_inputs();

    if (num == 1) {
      ctx->set_output(0, input0);
      return;
    }

    // Try to forward and accumulate the result in one of the input buffers.
    int reused_input = -1;
    gtl::InlinedVector<int, 8> input_indices(num);
    std::iota(input_indices.begin(), input_indices.end(), 0);
    Tensor* output = nullptr;
    for (int input_idx = 0; input_idx < num; ++input_idx) {
      if (ctx->forward_input_to_output_with_shape(input_idx, 0, input0.shape(),
                                                  &output)) {
        reused_input = input_idx;
        break;
      }
    }
    if (reused_input == -1) {
      OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input0.shape(), &output));
    } else if (reused_input > 0) {
      // Move the forwarded buffer to the front so we don't double count
      // anything if there are more than 8 inputs.
      input_indices[0] = reused_input;
      input_indices[reused_input] = 0;
    }
    auto To = output->flat<T>();

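// I(IDX) is the flat view of the input selected by input_indices[IDX]; after
// the swap above, I(0) may alias the forwarded output buffer.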
#define I(IDX) ctx->input(input_indices[IDX]).template flat<T>()

#if defined(__ANDROID_TYPES_SLIM__)
    // On Android by default, we only support additions of two arguments, so
    // we can reduce the number of template instantiations.
    OP_REQUIRES(ctx, num == 2,
                errors::InvalidArgument("Only additions of two arguments "
                                        "supported. Num inputs: ",
                                        num));
    functor::Add2Functor<Device, T> functor2;
    functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
#else
    static const int kWidth = 8;
    int r = num % kWidth;

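    // Compute an initial partial sum over the first 2 to 9 inputs, chosen so
    // that the number of remaining inputs is a multiple of kWidth; the rest
    // are folded in kWidth at a time below.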
    switch (r) {
      case 2: {
        functor::Add2Functor<Device, T> functor2;
        functor2(ctx->template eigen_device<Device>(), To, I(0), I(1));
        break;
      }
      case 3: {
        functor::Add3Functor<Device, T> functor3;
        functor3(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2));
        break;
      }
      case 4: {
        functor::Add4Functor<Device, T> functor4;
        functor4(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3));
        break;
      }
      case 5: {
        functor::Add5Functor<Device, T> functor5;
        functor5(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4));
        break;
      }
      case 6: {
        functor::Add6Functor<Device, T> functor6;
        functor6(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5));
        break;
      }
      case 7: {
        functor::Add7Functor<Device, T> functor7;
        functor7(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6));
        break;
      }
      case 0: {
        functor::Add8Functor<Device, T> functor8;
        functor8(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6), I(7));
        r = 8;
        break;
      }
      case 1: {
        functor::Add9Functor<Device, T> functor9;
        functor9(ctx->template eigen_device<Device>(), To, I(0), I(1), I(2),
                 I(3), I(4), I(5), I(6), I(7), I(8));
        r = 9;
        break;
      }
    }

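    // Add8pFunctor accumulates into To rather than assigning to it, so each
    // iteration folds eight more inputs into the partial sum computed above.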
    for (; r < num; r += kWidth) {
      functor::Add8pFunctor<Device, T> functor8p;
      functor8p(ctx->template eigen_device<Device>(), To, I(r), I(r + 1),
                I(r + 2), I(r + 3), I(r + 4), I(r + 5), I(r + 6), I(r + 7));
    }
#endif  // defined(__ANDROID_TYPES_SLIM__)

#undef I
  }
};

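// Partial specialization for variant tensors: the scalar variant inputs are
// summed with the type-registered ADD_VARIANT_BINARY_OP via the AddNVariant
// helper.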
template <typename Device>
class AddNOp<Device, Variant> : public OpKernel {
 public:
  explicit AddNOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    auto binary_add = [](OpKernelContext* cc_ctx, const Variant& a,
                         const Variant& b, Variant* out) {
      return BinaryOpVariants<Device>(cc_ctx, ADD_VARIANT_BINARY_OP, a, b, out);
    };
    AddNVariant(ctx, binary_add);
  }

 private:
  // AddVariantTo efficiently performs:
  //    temp[lhs_ix] <- array(lhs_ix) + array(rhs_ix)
  // where array(ix) := (temp_filled[ix]
  //                     ? temp[ix]
  //                     : ctx->input(ix).scalar<Variant>()())
  // This reduces (possibly expensive) copying of Variants from
  // the inputs into temp at the lowest levels of the summation tree.
  static inline Status AddVariantTo(OpKernelContext* ctx, const int lhs_ix,
                                    const int rhs_ix,
                                    gtl::InlinedVector<Variant, 4>* temp,
                                    gtl::InlinedVector<bool, 4>* temp_filled) {
    Variant tmp;
    if (temp_filled->at(lhs_ix)) tmp = std::move(temp->at(lhs_ix));
    const Variant& a = temp_filled->at(lhs_ix)
                           ? tmp
                           : ctx->input(lhs_ix).template scalar<Variant>()();
    const Variant& b = temp_filled->at(rhs_ix)
                           ? temp->at(rhs_ix)
                           : ctx->input(rhs_ix).template scalar<Variant>()();
    Variant* c = &temp->at(lhs_ix);
    TF_RETURN_IF_ERROR(
        BinaryOpVariants<Device>(ctx, ADD_VARIANT_BINARY_OP, a, b, c));
    temp_filled->at(lhs_ix) = true;
    return OkStatus();
  }
};

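// Kernel registrations: AddN for all numeric types and Variant on CPU, plus
// the GPU registrations below when built with CUDA or ROCm.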
#define REGISTER_ADDN(type, dev)                                   \
  REGISTER_KERNEL_BUILDER(                                         \
      Name("AddN").Device(DEVICE_##dev).TypeConstraint<type>("T"), \
      AddNOp<dev##Device, type>)

#define REGISTER_ADDN_CPU(type) REGISTER_ADDN(type, CPU)

TF_CALL_NUMBER_TYPES(REGISTER_ADDN_CPU);
REGISTER_ADDN_CPU(Variant);

#undef REGISTER_ADDN_CPU

#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
#define REGISTER_ADDN_GPU(type) REGISTER_ADDN(type, GPU)
TF_CALL_int64(REGISTER_ADDN_GPU);
TF_CALL_uint32(REGISTER_ADDN_GPU);
TF_CALL_variant(REGISTER_ADDN_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_ADDN_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_ADDN_GPU);
#undef REGISTER_ADDN_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// A special DEVICE_DEFAULT kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
REGISTER_KERNEL_BUILDER(Name("AddN")
                            .Device(DEVICE_DEFAULT)
                            .TypeConstraint<int32>("T")
                            .HostMemory("inputs")
                            .HostMemory("sum"),
                        AddNOp<CPUDevice, int32>);

#undef REGISTER_ADDN

}  // namespace tensorflow