/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#define EIGEN_USE_THREADS

#include <string>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_join.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/cwise_ops.h"
#include "tensorflow/core/kernels/cwise_ops_common.h"
#include "tensorflow/core/kernels/relu_op_functor.h"
namespace tensorflow {

template <typename T>
class UnaryOpsComposition;  // forward declare kernel

template <typename T>
struct UnaryOpsCompositionSupport;

template <typename T>
struct UnaryOpsCompositionBase {
  using InputBuffer = typename TTypes<T>::ConstFlat;
  using OutputBuffer = typename TTypes<T>::Flat;

  using ComputeFn = void (*)(const InputBuffer&, OutputBuffer*);

  struct ComputeFnRegistration {
    ComputeFn compute_fn;
    int cost;
  };

  bool HasComputeFn(const string& name) {
    return compute_fns.find(name) != compute_fns.end();
  }

 protected:
  void RegisterComputeFn(const string& name, ComputeFn compute_fn, int cost) {
    VLOG(5) << "Register compute fn: name=" << name << " cost=" << cost;
    compute_fns[name] = {compute_fn, cost};
  }

 private:
  friend class UnaryOpsComposition<T>;

  Status ExportComputeFns(const std::vector<string>& op_names,
                          std::vector<ComputeFn>* fns, int* cost) {
    for (const string& op_name : op_names) {
      auto it = compute_fns.find(op_name);
      if (it == compute_fns.end())
        return errors::InvalidArgument(
            "Do not have a compute function registered for op: ", op_name);

      const ComputeFnRegistration& reg = it->second;
      fns->push_back(reg.compute_fn);
      *cost += reg.cost;
    }

    return OkStatus();
  }

  std::unordered_map<string, ComputeFnRegistration> compute_fns;
};

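// UnaryOpsComposition<T> applies the chain of element-wise unary ops named in
// the "op_names" attribute in a single pass over the input, reusing the output
// buffer for intermediate results instead of materializing a tensor per op.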
template <typename T>
class UnaryOpsComposition : public OpKernel {
 public:
  using Kernel = UnaryOpsComposition<T>;

  using Scalar = T;
  using Packet = typename Eigen::internal::packet_traits<T>::type;

  using Support = UnaryOpsCompositionSupport<T>;

  using InputBuffer = typename Support::InputBuffer;
  using OutputBuffer = typename Support::OutputBuffer;
  using ComputeFn = typename Support::ComputeFn;

  explicit UnaryOpsComposition(OpKernelConstruction* context)
      : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("op_names", &op_names_));

    OP_REQUIRES(context, !op_names_.empty(),
                errors::InvalidArgument(
                    "Unary op composition must have at least one op"));

    OP_REQUIRES_OK(context,
                   support_.ExportComputeFns(op_names_, &fns_, &cost_));

    VLOG(2) << "Composed unary op: [" << absl::StrJoin(op_names_, ", ")
            << "]; cost=" << cost_;
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in = ctx->input(0);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(
        ctx, ctx->forward_input_or_allocate_output({0}, 0, in.shape(), &out));

    InputBuffer in_flat = in.flat<T>();
    OutputBuffer out_flat = out->flat<T>();

    const std::size_t num_fns = fns_.size();
    auto compute_fn = [this, &in_flat, &out_flat, &num_fns](int64_t begin,
                                                            int64_t end) {
      int64_t len = end - begin;
      const InputBuffer in_slice(in_flat.data() + begin, len);
      const InputBuffer scratch_slice(out_flat.data() + begin, len);
      OutputBuffer out_slice(out_flat.data() + begin, len);

      // The first op reads from the input slice; each subsequent op reads the
      // previous result back from the output buffer (scratch_slice aliases
      // out_slice), so no extra scratch allocation is needed.
      fns_[0](in_slice, &out_slice);
      for (int i = 1; i < num_fns; ++i) {
        fns_[i](scratch_slice, &out_slice);
      }
    };

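    // Shard the fused computation across the intra-op thread pool, using the
    // accumulated per-element cost estimate and AlignBlockSize (below) to pick
    // vectorization-friendly block sizes.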
    const CPUDevice& device = ctx->eigen_device<CPUDevice>();
    const int kOverheadCycles = static_cast<int>(num_fns) * 10;
    Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T) * num_fns,
                             /*bytes_stored=*/sizeof(T) * num_fns,
                             kOverheadCycles + cost_);
    device.parallelFor(in.NumElements(), cost, AlignBlockSize,
                       std::move(compute_fn));
  }

 private:
  static constexpr int kPacketSize =
      Eigen::internal::unpacket_traits<Packet>::size;

  static inline int64_t AlignBlockSize(int64_t block_size) {
    // Align block size to packet size and account for unrolling in run above.
    if (block_size >= 16 * kPacketSize) {
      return (block_size + 4 * kPacketSize - 1) & ~(4 * kPacketSize - 1);
    }
    // Aligning to 4 * PacketSize would increase block size by more than 25%.
    return (block_size + kPacketSize - 1) & ~(kPacketSize - 1);
  }

  Support support_;

  std::vector<string> op_names_;
  std::vector<ComputeFn> fns_;
  int cost_ = 0;
};

// Register compute functions for UnaryOp functors.
#define REGISTER_COMPUTE_FN_HELPER(name, functor)                              \
  static_assert(std::is_same<functor::in_type, functor::out_type>::value,      \
                "Functor must have same input and output types");              \
                                                                                \
  static inline void Compute##name(const InputBuffer& in, OutputBuffer* out) { \
    *out = in.unaryExpr(functor::func());                                      \
  }                                                                             \
  static inline int Cost##name() {                                             \
    return Eigen::internal::functor_traits<functor::func>::Cost;               \
  }

// Register compute functions for Relu/Relu6/Elu/Selu.
#define REGISTER_RELU_HELPER()                                                \
  template <typename T>                                                       \
  using functor_traits = Eigen::internal::functor_traits<T>;                  \
                                                                              \
  static inline void ComputeRelu(const InputBuffer& in, OutputBuffer* out) {  \
    auto relu = functor::Relu<Eigen::DefaultDevice, T>();                     \
    relu(Eigen::DefaultDevice(), in, *out);                                   \
  }                                                                           \
                                                                              \
  static inline int CostRelu() {                                              \
    return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost;           \
  }                                                                           \
                                                                              \
  static inline void ComputeRelu6(const InputBuffer& in, OutputBuffer* out) { \
    auto relu6 = functor::Relu6<Eigen::DefaultDevice, T>();                   \
    relu6(Eigen::DefaultDevice(), in, *out);                                  \
  }                                                                           \
                                                                              \
  static inline int CostRelu6() {                                             \
    return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost +          \
           functor_traits<Eigen::internal::scalar_min_op<T>>::Cost;           \
  }                                                                           \
  static inline void ComputeElu(const InputBuffer& in, OutputBuffer* out) {   \
    auto elu = functor::Elu<Eigen::DefaultDevice, T>();                       \
    elu(Eigen::DefaultDevice(), in, *out);                                    \
  }                                                                           \
                                                                              \
  static inline int CostElu() {                                               \
    return functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost +          \
           Eigen::NumTraits<T>::MulCost;                                      \
  }                                                                           \
  static inline void ComputeSelu(const InputBuffer& in, OutputBuffer* out) {  \
    auto selu = functor::Selu<Eigen::DefaultDevice, T>();                     \
    selu(Eigen::DefaultDevice(), in, *out);                                   \
  }                                                                           \
                                                                              \
  static inline int CostSelu() {                                              \
    return 2 * (functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost +     \
                Eigen::NumTraits<T>::MulCost);                                \
  }

#define REGISTER_COMPUTE_FN(func) \
  RegisterComputeFn(#func, Compute##func, Cost##func());

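// Per-type support tables. Each specialization registers the compute functions
// available for that dtype, keyed by op name, so ExportComputeFns can resolve
// the "op_names" attribute at kernel-construction time.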
template <>
struct UnaryOpsCompositionSupport<float> : UnaryOpsCompositionBase<float> {
  using T = float;

  UnaryOpsCompositionSupport() {
    // UnaryOp functors.
    REGISTER_COMPUTE_FN(Abs);
    REGISTER_COMPUTE_FN(Acos);
    REGISTER_COMPUTE_FN(Acosh);
    REGISTER_COMPUTE_FN(Asin);
    REGISTER_COMPUTE_FN(Asinh);
    REGISTER_COMPUTE_FN(Atan);
    REGISTER_COMPUTE_FN(Atanh);
    REGISTER_COMPUTE_FN(Ceil);
    REGISTER_COMPUTE_FN(Cos);
    REGISTER_COMPUTE_FN(Cosh);
    REGISTER_COMPUTE_FN(Expm1);
    REGISTER_COMPUTE_FN(Exp);
    REGISTER_COMPUTE_FN(Floor);
    REGISTER_COMPUTE_FN(Inv);
    REGISTER_COMPUTE_FN(Log);
    REGISTER_COMPUTE_FN(Log1p);
    REGISTER_COMPUTE_FN(Neg);
    REGISTER_COMPUTE_FN(Reciprocal);
    REGISTER_COMPUTE_FN(Rint);
    REGISTER_COMPUTE_FN(Round);
    REGISTER_COMPUTE_FN(Rsqrt);
    REGISTER_COMPUTE_FN(Sigmoid);
    REGISTER_COMPUTE_FN(Sin);
    REGISTER_COMPUTE_FN(Sinh);
    REGISTER_COMPUTE_FN(Sqrt);
    REGISTER_COMPUTE_FN(Square);
    REGISTER_COMPUTE_FN(Tan);
    REGISTER_COMPUTE_FN(Tanh);

    // Additional compute functions not defined via UnaryOp functors.
    REGISTER_COMPUTE_FN(Elu);
    REGISTER_COMPUTE_FN(Relu);
    REGISTER_COMPUTE_FN(Relu6);
    REGISTER_COMPUTE_FN(Selu);
  }

  REGISTER_RELU_HELPER();

  // clang-format off
  REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>);
  REGISTER_COMPUTE_FN_HELPER(Acos, functor::acos<T>);
  REGISTER_COMPUTE_FN_HELPER(Acosh, functor::acosh<T>);
  REGISTER_COMPUTE_FN_HELPER(Asin, functor::asin<T>);
  REGISTER_COMPUTE_FN_HELPER(Asinh, functor::asinh<T>);
  REGISTER_COMPUTE_FN_HELPER(Atan, functor::atan<T>);
  REGISTER_COMPUTE_FN_HELPER(Atanh, functor::atanh<T>);
  REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>);
  REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>);
  REGISTER_COMPUTE_FN_HELPER(Cosh, functor::cosh<T>);
  REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>);
  REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>);
  REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>);
  REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>);
  REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>);
  REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>);
  REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Rint, functor::rint<T>);
  REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>);
  REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>);
  REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>);
  REGISTER_COMPUTE_FN_HELPER(Sinh, functor::sinh<T>);
  REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>);
  REGISTER_COMPUTE_FN_HELPER(Tan, functor::tan<T>);
  REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>);
  // clang-format on
};

template <>
struct UnaryOpsCompositionSupport<Eigen::half>
    : UnaryOpsCompositionBase<Eigen::half> {
  using T = Eigen::half;

  UnaryOpsCompositionSupport() {
    REGISTER_COMPUTE_FN(Abs);
    REGISTER_COMPUTE_FN(Ceil);
    REGISTER_COMPUTE_FN(Cos);
    REGISTER_COMPUTE_FN(Expm1);
    REGISTER_COMPUTE_FN(Exp);
    REGISTER_COMPUTE_FN(Floor);
    REGISTER_COMPUTE_FN(Inv);
    REGISTER_COMPUTE_FN(Log);
    REGISTER_COMPUTE_FN(Log1p);
    REGISTER_COMPUTE_FN(Neg);
    REGISTER_COMPUTE_FN(Reciprocal);
    REGISTER_COMPUTE_FN(Round);
    REGISTER_COMPUTE_FN(Rsqrt);
    REGISTER_COMPUTE_FN(Sigmoid);
    REGISTER_COMPUTE_FN(Sin);
    REGISTER_COMPUTE_FN(Sqrt);
    REGISTER_COMPUTE_FN(Square);
    REGISTER_COMPUTE_FN(Tanh);
    // Additional compute functions not defined via UnaryOp functors.
    REGISTER_COMPUTE_FN(Elu);
    REGISTER_COMPUTE_FN(Relu);
    REGISTER_COMPUTE_FN(Relu6);
    REGISTER_COMPUTE_FN(Selu);
  }

  REGISTER_RELU_HELPER();

  // clang-format off
  REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>);
  REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>);
  REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>);
  REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>);
  REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>);
  REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>);
  REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>);
  REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>);
  REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>);
  REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>);
  REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>);
  REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>);
  REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>);
  REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>);
  // clang-format on
};

template <>
struct UnaryOpsCompositionSupport<double> : UnaryOpsCompositionBase<double> {
  using T = double;

  UnaryOpsCompositionSupport() {
    REGISTER_COMPUTE_FN(Abs);
    REGISTER_COMPUTE_FN(Acos);
    REGISTER_COMPUTE_FN(Acosh);
    REGISTER_COMPUTE_FN(Asin);
    REGISTER_COMPUTE_FN(Asinh);
    REGISTER_COMPUTE_FN(Atan);
    REGISTER_COMPUTE_FN(Atanh);
    REGISTER_COMPUTE_FN(Ceil);
    REGISTER_COMPUTE_FN(Cos);
    REGISTER_COMPUTE_FN(Cosh);
    REGISTER_COMPUTE_FN(Expm1);
    REGISTER_COMPUTE_FN(Exp);
    REGISTER_COMPUTE_FN(Floor);
    REGISTER_COMPUTE_FN(Inv);
    REGISTER_COMPUTE_FN(Log);
    REGISTER_COMPUTE_FN(Log1p);
    REGISTER_COMPUTE_FN(Neg);
    REGISTER_COMPUTE_FN(Reciprocal);
    REGISTER_COMPUTE_FN(Rint);
    REGISTER_COMPUTE_FN(Round);
    REGISTER_COMPUTE_FN(Rsqrt);
    REGISTER_COMPUTE_FN(Sigmoid);
    REGISTER_COMPUTE_FN(Sin);
    REGISTER_COMPUTE_FN(Sinh);
    REGISTER_COMPUTE_FN(Sqrt);
    REGISTER_COMPUTE_FN(Square);
    REGISTER_COMPUTE_FN(Tan);
    REGISTER_COMPUTE_FN(Tanh);
    // Additional compute functions not defined via UnaryOp functors.
    REGISTER_COMPUTE_FN(Elu);
    REGISTER_COMPUTE_FN(Relu);
    REGISTER_COMPUTE_FN(Relu6);
    REGISTER_COMPUTE_FN(Selu);
  }

  REGISTER_RELU_HELPER();

  // clang-format off
  REGISTER_COMPUTE_FN_HELPER(Abs, functor::abs<T>);
  REGISTER_COMPUTE_FN_HELPER(Acos, functor::acos<T>);
  REGISTER_COMPUTE_FN_HELPER(Acosh, functor::acosh<T>);
  REGISTER_COMPUTE_FN_HELPER(Asin, functor::asin<T>);
  REGISTER_COMPUTE_FN_HELPER(Asinh, functor::asinh<T>);
  REGISTER_COMPUTE_FN_HELPER(Atan, functor::atan<T>);
  REGISTER_COMPUTE_FN_HELPER(Atanh, functor::atanh<T>);
  REGISTER_COMPUTE_FN_HELPER(Ceil, functor::ceil<T>);
  REGISTER_COMPUTE_FN_HELPER(Cos, functor::cos<T>);
  REGISTER_COMPUTE_FN_HELPER(Cosh, functor::cosh<T>);
  REGISTER_COMPUTE_FN_HELPER(Expm1, functor::expm1<T>);
  REGISTER_COMPUTE_FN_HELPER(Exp, functor::exp<T>);
  REGISTER_COMPUTE_FN_HELPER(Floor, functor::floor<T>);
  REGISTER_COMPUTE_FN_HELPER(Inv, functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Log, functor::log<T>);
  REGISTER_COMPUTE_FN_HELPER(Log1p, functor::log1p<T>);
  REGISTER_COMPUTE_FN_HELPER(Neg, functor::neg<T>);
  REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Rint, functor::rint<T>);
  REGISTER_COMPUTE_FN_HELPER(Round, functor::round<T>);
  REGISTER_COMPUTE_FN_HELPER(Rsqrt, functor::rsqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Sigmoid, functor::sigmoid<T>);
  REGISTER_COMPUTE_FN_HELPER(Sin, functor::sin<T>);
  REGISTER_COMPUTE_FN_HELPER(Sinh, functor::sinh<T>);
  REGISTER_COMPUTE_FN_HELPER(Sqrt, functor::sqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Square, functor::square<T>);
  REGISTER_COMPUTE_FN_HELPER(Tan, functor::tan<T>);
  REGISTER_COMPUTE_FN_HELPER(Tanh, functor::tanh<T>);
  // clang-format on
};

// Register the CPU kernels.
#define REGISTER_CPU(T)                                                       \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("_UnaryOpsComposition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      UnaryOpsComposition<T>);

REGISTER_CPU(float);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(double);

#undef REGISTER_CPU

}  // namespace tensorflow