1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/math_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/kernels/cwise_ops.h" |
23 | #include "tensorflow/core/kernels/cwise_ops_common.h" |
24 | #include "tensorflow/core/kernels/relu_op_functor.h" |
25 | |
26 | namespace tensorflow { |
27 | |
// Forward declaration of the kernel so the support base class can befriend it.
template <typename T>
class UnaryOpsComposition;  // forward declare kernel

// Per-type registry of supported unary compute functions; specialized below
// for float, Eigen::half and double.
template <typename T>
struct UnaryOpsCompositionSupport;
33 | |
34 | template <typename T> |
35 | struct UnaryOpsCompositionBase { |
36 | using InputBuffer = typename TTypes<T>::ConstFlat; |
37 | using OutputBuffer = typename TTypes<T>::Flat; |
38 | |
39 | using ComputeFn = void (*)(const InputBuffer&, OutputBuffer*); |
40 | |
41 | struct ComputeFnRegistration { |
42 | ComputeFn compute_fn; |
43 | int cost; |
44 | }; |
45 | |
46 | bool HasComputeFn(const string& name) { |
47 | return compute_fns.find(name) != compute_fns.end(); |
48 | } |
49 | |
50 | protected: |
51 | void RegisterComputeFn(const string& name, ComputeFn compute_fn, int cost) { |
52 | VLOG(5) << "Register compute fn: name=" << name << " cost=" << cost; |
53 | compute_fns[name] = {compute_fn, cost}; |
54 | } |
55 | |
56 | private: |
57 | friend class UnaryOpsComposition<T>; |
58 | |
59 | Status ExportComputeFns(const std::vector<string>& op_names, |
60 | std::vector<ComputeFn>* fns, int* cost) { |
61 | for (const string& op_name : op_names) { |
62 | auto it = compute_fns.find(op_name); |
63 | if (it == compute_fns.end()) |
64 | return errors::InvalidArgument( |
65 | "Do not have a compute function registered for op: " , op_name); |
66 | |
67 | const ComputeFnRegistration& reg = it->second; |
68 | fns->push_back(reg.compute_fn); |
69 | *cost += reg.cost; |
70 | } |
71 | |
72 | return OkStatus(); |
73 | } |
74 | |
75 | std::unordered_map<string, ComputeFnRegistration> compute_fns; |
76 | }; |
77 | |
78 | template <typename T> |
79 | class UnaryOpsComposition : public OpKernel { |
80 | public: |
81 | using Kernel = UnaryOpsComposition<T>; |
82 | |
83 | using Scalar = T; |
84 | using Packet = typename Eigen::internal::packet_traits<T>::type; |
85 | |
86 | using Support = UnaryOpsCompositionSupport<T>; |
87 | |
88 | using InputBuffer = typename Support::InputBuffer; |
89 | using OutputBuffer = typename Support::OutputBuffer; |
90 | using ComputeFn = typename Support::ComputeFn; |
91 | |
92 | explicit UnaryOpsComposition(OpKernelConstruction* context) |
93 | : OpKernel(context) { |
94 | OP_REQUIRES_OK(context, context->GetAttr("op_names" , &op_names_)); |
95 | |
96 | OP_REQUIRES(context, !op_names_.empty(), |
97 | errors::InvalidArgument( |
98 | "Unary op composition must have at least one op" )); |
99 | |
100 | OP_REQUIRES_OK(context, |
101 | support_.ExportComputeFns(op_names_, &fns_, &cost_)); |
102 | |
103 | VLOG(2) << "Composed unary op: [" << absl::StrJoin(op_names_, ", " ) |
104 | << "]; cost=" << cost_; |
105 | } |
106 | |
107 | void Compute(OpKernelContext* ctx) override { |
108 | const Tensor& in = ctx->input(0); |
109 | Tensor* out = nullptr; |
110 | OP_REQUIRES_OK( |
111 | ctx, ctx->forward_input_or_allocate_output({0}, 0, in.shape(), &out)); |
112 | |
113 | InputBuffer in_flat = in.flat<T>(); |
114 | OutputBuffer out_flat = out->flat<T>(); |
115 | |
116 | const std::size_t num_fns = fns_.size(); |
117 | auto compute_fn = [this, &in_flat, &out_flat, &num_fns](int64_t begin, |
118 | int64_t end) { |
119 | int64_t len = end - begin; |
120 | const InputBuffer in_slice(in_flat.data() + begin, len); |
121 | const InputBuffer scratch_slice(out_flat.data() + begin, len); |
122 | OutputBuffer out_slice(out_flat.data() + begin, len); |
123 | |
124 | fns_[0](in_slice, &out_slice); |
125 | for (int i = 1; i < num_fns; ++i) { |
126 | fns_[i](scratch_slice, &out_slice); |
127 | } |
128 | }; |
129 | |
130 | const CPUDevice& device = ctx->eigen_device<CPUDevice>(); |
131 | const int kOverheadCycles = static_cast<int>(num_fns) * 10; |
132 | Eigen::TensorOpCost cost(/*bytes_loaded=*/sizeof(T) * num_fns, |
133 | /*bytes_stored=*/sizeof(T) * num_fns, |
134 | kOverheadCycles + cost_); |
135 | device.parallelFor(in.NumElements(), cost, AlignBlockSize, |
136 | std::move(compute_fn)); |
137 | } |
138 | |
139 | private: |
140 | static constexpr int kPacketSize = |
141 | Eigen::internal::unpacket_traits<Packet>::size; |
142 | |
143 | static inline int64_t AlignBlockSize(int64_t block_size) { |
144 | // Align block size to packet size and account for unrolling in run above. |
145 | if (block_size >= 16 * kPacketSize) { |
146 | return (block_size + 4 * kPacketSize - 1) & ~(4 * kPacketSize - 1); |
147 | } |
148 | // Aligning to 4 * PacketSize would increase block size by more than 25%. |
149 | return (block_size + kPacketSize - 1) & ~(kPacketSize - 1); |
150 | } |
151 | |
152 | Support support_; |
153 | |
154 | std::vector<string> op_names_; |
155 | std::vector<ComputeFn> fns_; |
156 | int cost_ = 0; |
157 | }; |
158 | |
// Register compute functions for UnaryOp functors.
//
// Expands (inside a UnaryOpsCompositionSupport specialization) to a pair of
// static helpers for a unary functor:
//   Compute<name>(in, out) — applies the functor elementwise via unaryExpr;
//   Cost<name>()           — the functor's Eigen per-element cost estimate.
// The static_assert guards the in-place chaining in UnaryOpsComposition,
// which requires each op to map T -> T.
#define REGISTER_COMPUTE_FN_HELPER(name, functor)                          \
  static_assert(std::is_same<functor::in_type, functor::out_type>::value,  \
                "Functor must have same input and output types");          \
                                                                           \
  static inline void Compute##name(const InputBuffer& in, OutputBuffer* out) { \
    *out = in.unaryExpr(functor::func());                                  \
  }                                                                        \
  static inline int Cost##name() {                                         \
    return Eigen::internal::functor_traits<functor::func>::Cost;           \
  }
170 | |
// Register compute function for the Relu/Relu6/Elu/Selu.
//
// These activations are implemented by dedicated functors in
// relu_op_functor.h (not plain Eigen unary functors), so they get hand-written
// Compute*/Cost* helpers here, invoked on the default (single-threaded) Eigen
// device — the caller (UnaryOpsComposition::Compute) already shards the work.
// The Cost* values are estimates built from the costs of the primitive Eigen
// ops each activation is composed of (max, min, exp, mul).
#define REGISTER_RELU_HELPER()                                                \
  template <typename T>                                                      \
  using functor_traits = Eigen::internal::functor_traits<T>;                 \
                                                                             \
  static inline void ComputeRelu(const InputBuffer& in, OutputBuffer* out) { \
    auto relu = functor::Relu<Eigen::DefaultDevice, T>();                    \
    relu(Eigen::DefaultDevice(), in, *out);                                  \
  }                                                                          \
                                                                             \
  static inline int CostRelu() {                                             \
    return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost;          \
  }                                                                          \
                                                                             \
  static inline void ComputeRelu6(const InputBuffer& in, OutputBuffer* out) { \
    auto relu6 = functor::Relu6<Eigen::DefaultDevice, T>();                  \
    relu6(Eigen::DefaultDevice(), in, *out);                                 \
  }                                                                          \
                                                                             \
  static inline int CostRelu6() {                                            \
    return functor_traits<Eigen::internal::scalar_max_op<T>>::Cost +         \
           functor_traits<Eigen::internal::scalar_min_op<T>>::Cost;          \
  }                                                                          \
  static inline void ComputeElu(const InputBuffer& in, OutputBuffer* out) {  \
    auto elu = functor::Elu<Eigen::DefaultDevice, T>();                      \
    elu(Eigen::DefaultDevice(), in, *out);                                   \
  }                                                                          \
                                                                             \
  static inline int CostElu() {                                              \
    return functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost +         \
           Eigen::NumTraits<T>::MulCost;                                     \
  }                                                                          \
  static inline void ComputeSelu(const InputBuffer& in, OutputBuffer* out) { \
    auto selu = functor::Selu<Eigen::DefaultDevice, T>();                    \
    selu(Eigen::DefaultDevice(), in, *out);                                  \
  }                                                                          \
                                                                             \
  static inline int CostSelu() {                                             \
    return 2 * (functor_traits<Eigen::internal::scalar_exp_op<T>>::Cost +    \
                Eigen::NumTraits<T>::MulCost);                               \
  }
212 | |
// Adds `func` to the registry under its stringified op name, wiring together
// the Compute<func>/Cost<func> helpers emitted by the macros above. Must be
// invoked inside a UnaryOpsCompositionSupport constructor.
#define REGISTER_COMPUTE_FN(func) \
  RegisterComputeFn(#func, Compute##func, Cost##func());
215 | |
// Registry of unary compute functions supported for float inputs. The
// constructor populates the name -> function map; the macro expansions below
// define the static Compute*/Cost* helpers those registrations refer to.
template <>
struct UnaryOpsCompositionSupport<float> : UnaryOpsCompositionBase<float> {
  using T = float;

  UnaryOpsCompositionSupport() {
    // UnaryOp functors.
    REGISTER_COMPUTE_FN(Abs);
    REGISTER_COMPUTE_FN(Acos);
    REGISTER_COMPUTE_FN(Acosh);
    REGISTER_COMPUTE_FN(Asin);
    REGISTER_COMPUTE_FN(Asinh);
    REGISTER_COMPUTE_FN(Atan);
    REGISTER_COMPUTE_FN(Atanh);
    REGISTER_COMPUTE_FN(Ceil);
    REGISTER_COMPUTE_FN(Cos);
    REGISTER_COMPUTE_FN(Cosh);
    REGISTER_COMPUTE_FN(Expm1);
    REGISTER_COMPUTE_FN(Exp);
    REGISTER_COMPUTE_FN(Floor);
    REGISTER_COMPUTE_FN(Inv);
    REGISTER_COMPUTE_FN(Log);
    REGISTER_COMPUTE_FN(Log1p);
    REGISTER_COMPUTE_FN(Neg);
    REGISTER_COMPUTE_FN(Reciprocal);
    REGISTER_COMPUTE_FN(Rint);
    REGISTER_COMPUTE_FN(Round);
    REGISTER_COMPUTE_FN(Rsqrt);
    REGISTER_COMPUTE_FN(Sigmoid);
    REGISTER_COMPUTE_FN(Sin);
    REGISTER_COMPUTE_FN(Sinh);
    REGISTER_COMPUTE_FN(Sqrt);
    REGISTER_COMPUTE_FN(Square);
    REGISTER_COMPUTE_FN(Tan);
    REGISTER_COMPUTE_FN(Tanh);

    // Additional compute functions not defined via UnaryOp functors.
    REGISTER_COMPUTE_FN(Elu);
    REGISTER_COMPUTE_FN(Relu);
    REGISTER_COMPUTE_FN(Relu6);
    REGISTER_COMPUTE_FN(Selu);
  }

  // Defines ComputeRelu/Relu6/Elu/Selu and their Cost* helpers.
  REGISTER_RELU_HELPER();

  // Defines Compute*/Cost* helpers for each UnaryOp functor registered above.
  // clang-format off
  REGISTER_COMPUTE_FN_HELPER(Abs,        functor::abs<T>);
  REGISTER_COMPUTE_FN_HELPER(Acos,       functor::acos<T>);
  REGISTER_COMPUTE_FN_HELPER(Acosh,      functor::acosh<T>);
  REGISTER_COMPUTE_FN_HELPER(Asin,       functor::asin<T>);
  REGISTER_COMPUTE_FN_HELPER(Asinh,      functor::asinh<T>);
  REGISTER_COMPUTE_FN_HELPER(Atan,       functor::atan<T>);
  REGISTER_COMPUTE_FN_HELPER(Atanh,      functor::atanh<T>);
  REGISTER_COMPUTE_FN_HELPER(Ceil,       functor::ceil<T>);
  REGISTER_COMPUTE_FN_HELPER(Cos,        functor::cos<T>);
  REGISTER_COMPUTE_FN_HELPER(Cosh,       functor::cosh<T>);
  REGISTER_COMPUTE_FN_HELPER(Expm1,      functor::expm1<T>);
  REGISTER_COMPUTE_FN_HELPER(Exp,        functor::exp<T>);
  REGISTER_COMPUTE_FN_HELPER(Floor,      functor::floor<T>);
  REGISTER_COMPUTE_FN_HELPER(Inv,        functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Log,        functor::log<T>);
  REGISTER_COMPUTE_FN_HELPER(Log1p,      functor::log1p<T>);
  REGISTER_COMPUTE_FN_HELPER(Neg,        functor::neg<T>);
  REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Rint,       functor::rint<T>);
  REGISTER_COMPUTE_FN_HELPER(Round,      functor::round<T>);
  REGISTER_COMPUTE_FN_HELPER(Rsqrt,      functor::rsqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Sigmoid,    functor::sigmoid<T>);
  REGISTER_COMPUTE_FN_HELPER(Sin,        functor::sin<T>);
  REGISTER_COMPUTE_FN_HELPER(Sinh,       functor::sinh<T>);
  REGISTER_COMPUTE_FN_HELPER(Sqrt,       functor::sqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Square,     functor::square<T>);
  REGISTER_COMPUTE_FN_HELPER(Tan,        functor::tan<T>);
  REGISTER_COMPUTE_FN_HELPER(Tanh,       functor::tanh<T>);
  // clang-format on
};
291 | |
// Registry of unary compute functions supported for Eigen::half inputs.
// Registers a smaller set of ops than the float/double specializations (no
// Acos/Acosh/Asin/Asinh/Atan/Atanh/Cosh/Sinh/Tan/Rint here).
template <>
struct UnaryOpsCompositionSupport<Eigen::half>
    : UnaryOpsCompositionBase<Eigen::half> {
  using T = Eigen::half;

  UnaryOpsCompositionSupport() {
    REGISTER_COMPUTE_FN(Abs);
    REGISTER_COMPUTE_FN(Ceil);
    REGISTER_COMPUTE_FN(Cos);
    REGISTER_COMPUTE_FN(Expm1);
    REGISTER_COMPUTE_FN(Exp);
    REGISTER_COMPUTE_FN(Floor);
    REGISTER_COMPUTE_FN(Inv);
    REGISTER_COMPUTE_FN(Log);
    REGISTER_COMPUTE_FN(Log1p);
    REGISTER_COMPUTE_FN(Neg);
    REGISTER_COMPUTE_FN(Reciprocal);
    REGISTER_COMPUTE_FN(Round);
    REGISTER_COMPUTE_FN(Rsqrt);
    REGISTER_COMPUTE_FN(Sigmoid);
    REGISTER_COMPUTE_FN(Sin);
    REGISTER_COMPUTE_FN(Sqrt);
    REGISTER_COMPUTE_FN(Square);
    REGISTER_COMPUTE_FN(Tanh);
    // Additional compute functions not defined via UnaryOp functors.
    REGISTER_COMPUTE_FN(Elu);
    REGISTER_COMPUTE_FN(Relu);
    REGISTER_COMPUTE_FN(Relu6);
    REGISTER_COMPUTE_FN(Selu);
  }

  // Defines ComputeRelu/Relu6/Elu/Selu and their Cost* helpers.
  REGISTER_RELU_HELPER();

  // Defines Compute*/Cost* helpers for each UnaryOp functor registered above.
  // clang-format off
  REGISTER_COMPUTE_FN_HELPER(Abs,        functor::abs<T>);
  REGISTER_COMPUTE_FN_HELPER(Ceil,       functor::ceil<T>);
  REGISTER_COMPUTE_FN_HELPER(Cos,        functor::cos<T>);
  REGISTER_COMPUTE_FN_HELPER(Expm1,      functor::expm1<T>);
  REGISTER_COMPUTE_FN_HELPER(Exp,        functor::exp<T>);
  REGISTER_COMPUTE_FN_HELPER(Floor,      functor::floor<T>);
  REGISTER_COMPUTE_FN_HELPER(Inv,        functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Log,        functor::log<T>);
  REGISTER_COMPUTE_FN_HELPER(Log1p,      functor::log1p<T>);
  REGISTER_COMPUTE_FN_HELPER(Neg,        functor::neg<T>);
  REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Round,      functor::round<T>);
  REGISTER_COMPUTE_FN_HELPER(Rsqrt,      functor::rsqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Sigmoid,    functor::sigmoid<T>);
  REGISTER_COMPUTE_FN_HELPER(Sin,        functor::sin<T>);
  REGISTER_COMPUTE_FN_HELPER(Sqrt,       functor::sqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Square,     functor::square<T>);
  REGISTER_COMPUTE_FN_HELPER(Tanh,       functor::tanh<T>);
  // clang-format on
};
346 | |
// Registry of unary compute functions supported for double inputs. Mirrors
// the float specialization's op set.
template <>
struct UnaryOpsCompositionSupport<double> : UnaryOpsCompositionBase<double> {
  using T = double;

  UnaryOpsCompositionSupport() {
    REGISTER_COMPUTE_FN(Abs);
    REGISTER_COMPUTE_FN(Acos);
    REGISTER_COMPUTE_FN(Acosh);
    REGISTER_COMPUTE_FN(Asin);
    REGISTER_COMPUTE_FN(Asinh);
    REGISTER_COMPUTE_FN(Atan);
    REGISTER_COMPUTE_FN(Atanh);
    REGISTER_COMPUTE_FN(Ceil);
    REGISTER_COMPUTE_FN(Cos);
    REGISTER_COMPUTE_FN(Cosh);
    REGISTER_COMPUTE_FN(Expm1);
    REGISTER_COMPUTE_FN(Exp);
    REGISTER_COMPUTE_FN(Floor);
    REGISTER_COMPUTE_FN(Inv);
    REGISTER_COMPUTE_FN(Log);
    REGISTER_COMPUTE_FN(Log1p);
    REGISTER_COMPUTE_FN(Neg);
    REGISTER_COMPUTE_FN(Reciprocal);
    REGISTER_COMPUTE_FN(Rint);
    REGISTER_COMPUTE_FN(Round);
    REGISTER_COMPUTE_FN(Rsqrt);
    REGISTER_COMPUTE_FN(Sigmoid);
    REGISTER_COMPUTE_FN(Sin);
    REGISTER_COMPUTE_FN(Sinh);
    REGISTER_COMPUTE_FN(Sqrt);
    REGISTER_COMPUTE_FN(Square);
    REGISTER_COMPUTE_FN(Tan);
    REGISTER_COMPUTE_FN(Tanh);
    // Additional compute functions not defined via UnaryOp functors.
    REGISTER_COMPUTE_FN(Elu);
    REGISTER_COMPUTE_FN(Relu);
    REGISTER_COMPUTE_FN(Relu6);
    REGISTER_COMPUTE_FN(Selu);
  }

  // Defines ComputeRelu/Relu6/Elu/Selu and their Cost* helpers.
  REGISTER_RELU_HELPER();

  // Defines Compute*/Cost* helpers for each UnaryOp functor registered above.
  // clang-format off
  REGISTER_COMPUTE_FN_HELPER(Abs,        functor::abs<T>);
  REGISTER_COMPUTE_FN_HELPER(Acos,       functor::acos<T>);
  REGISTER_COMPUTE_FN_HELPER(Acosh,      functor::acosh<T>);
  REGISTER_COMPUTE_FN_HELPER(Asin,       functor::asin<T>);
  REGISTER_COMPUTE_FN_HELPER(Asinh,      functor::asinh<T>);
  REGISTER_COMPUTE_FN_HELPER(Atan,       functor::atan<T>);
  REGISTER_COMPUTE_FN_HELPER(Atanh,      functor::atanh<T>);
  REGISTER_COMPUTE_FN_HELPER(Ceil,       functor::ceil<T>);
  REGISTER_COMPUTE_FN_HELPER(Cos,        functor::cos<T>);
  REGISTER_COMPUTE_FN_HELPER(Cosh,       functor::cosh<T>);
  REGISTER_COMPUTE_FN_HELPER(Expm1,      functor::expm1<T>);
  REGISTER_COMPUTE_FN_HELPER(Exp,        functor::exp<T>);
  REGISTER_COMPUTE_FN_HELPER(Floor,      functor::floor<T>);
  REGISTER_COMPUTE_FN_HELPER(Inv,        functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Log,        functor::log<T>);
  REGISTER_COMPUTE_FN_HELPER(Log1p,      functor::log1p<T>);
  REGISTER_COMPUTE_FN_HELPER(Neg,        functor::neg<T>);
  REGISTER_COMPUTE_FN_HELPER(Reciprocal, functor::inverse<T>);
  REGISTER_COMPUTE_FN_HELPER(Rint,       functor::rint<T>);
  REGISTER_COMPUTE_FN_HELPER(Round,      functor::round<T>);
  REGISTER_COMPUTE_FN_HELPER(Rsqrt,      functor::rsqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Sigmoid,    functor::sigmoid<T>);
  REGISTER_COMPUTE_FN_HELPER(Sin,        functor::sin<T>);
  REGISTER_COMPUTE_FN_HELPER(Sinh,       functor::sinh<T>);
  REGISTER_COMPUTE_FN_HELPER(Sqrt,       functor::sqrt<T>);
  REGISTER_COMPUTE_FN_HELPER(Square,     functor::square<T>);
  REGISTER_COMPUTE_FN_HELPER(Tan,        functor::tan<T>);
  REGISTER_COMPUTE_FN_HELPER(Tanh,       functor::tanh<T>);
  // clang-format on
};
420 | |
421 | // Register the CPU kernels. |
// Registers the _UnaryOpsComposition kernel on CPU for element type T; the
// op name is underscore-prefixed because the op is created internally (by a
// graph rewrite), not by users directly.
#define REGISTER_CPU(T)                                                       \
  REGISTER_KERNEL_BUILDER(                                                    \
      Name("_UnaryOpsComposition").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      UnaryOpsComposition<T>);

// Only the types with a UnaryOpsCompositionSupport specialization above.
REGISTER_CPU(float);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(double);

#undef REGISTER_CPU
432 | |
433 | } // namespace tensorflow |
434 | |