/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/framework/tensor_types.h" |
23 | #include "tensorflow/core/util/tensor_format.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace functor { |
27 | |
// FusedBatchNormEx op supports side inputs and activations:
//   (1) batch_norm + activation
//   (2) batch_norm + side input + activation
// kIdentity leaves the normalized output unchanged; kRelu applies a ReLU.
// Parsed from the op attribute by ParseActivationMode() below.
enum class FusedBatchNormActivationMode { kIdentity, kRelu };
32 | |
// Returns a human-readable name for `activation_mode` (useful in error
// messages and logs). Defined out-of-line in the corresponding .cc file.
std::string ToString(FusedBatchNormActivationMode activation_mode);
34 | |
// Reads the activation-mode attribute from `context` during kernel
// construction and stores the parsed value in `*activation_mode`.
// Returns a non-OK Status on failure.
// NOTE(review): the exact attribute name and the set of accepted strings live
// in the out-of-line definition — confirm there before relying on them.
Status ParseActivationMode(OpKernelConstruction* context,
                           FusedBatchNormActivationMode* activation_mode);
37 | |
38 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
39 | |
// This is a functor to launch custom CUDA kernel for FusedBatchNorm with side
// input and activation when 'is_training=False'. In training we rely on cuDNN.
//
// Template parameters:
//   Device: device tag selecting the specialization (GPU-only; this template
//           is guarded by GOOGLE_CUDA || TENSORFLOW_USE_ROCM).
//   T: element type of the 4-D input/side-input/output tensors.
//   U: element type of scale/offset/mean/variance vectors and of epsilon.
//
// Declared without a definition here: instantiating an unspecialized
// combination fails at link time rather than silently compiling.
template <typename Device, typename T, typename U>
struct FusedBatchNormInferenceFunctor {
  // Normalizes `in` with the estimated (inference-time) statistics, applies
  // scale/offset, optionally adds `side_input`, applies `activation_mode`,
  // and writes the result to `out`. 4-D tensor layout is described by
  // `tensor_format`.
  // NOTE(review): the precise normalization math is in the out-of-line
  // definition — confirm against the FusedBatchNormEx op spec.
  void operator()(OpKernelContext* context, TensorFormat tensor_format,
                  typename TTypes<T, 4>::ConstTensor in,
                  typename TTypes<U>::ConstVec scale,
                  typename TTypes<U>::ConstVec offset,
                  typename TTypes<U>::ConstVec estimated_mean,
                  typename TTypes<U>::ConstVec estimated_variance,
                  typename TTypes<T, 4>::ConstTensor side_input, U epsilon,
                  FusedBatchNormActivationMode activation_mode,
                  typename TTypes<T, 4>::Tensor out);
};
54 | |
55 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
56 | |
// Functor used by FusedBatchNormGradOp to do the computations when
// is_training=False.
//
// NOTE(review): this primary template has an empty body, so instantiating it
// for a Device without a matching specialization compiles and links but
// silently computes nothing (the output tensors are never written). Real
// implementations are expected to come from per-device specializations
// defined elsewhere — confirm one exists for every Device in use. Contrast
// with FusedBatchNormInferenceFunctor above, which is declaration-only and
// therefore turns a missing specialization into a link error.
template <typename Device, typename T, typename U>
struct FusedBatchNormFreezeGrad {
  // Presumably computes backprops w.r.t. x, scale and offset from the
  // forward inputs and the frozen (population) mean/variance — inferred from
  // parameter names; verify against the specializations. No-op here.
  void operator()(OpKernelContext* context, const Tensor& y_backprop_input,
                  const Tensor& x_input, const Tensor& scale_input,
                  const Tensor& pop_mean_input,
                  const Tensor& pop_variance_input, U epsilon,
                  Tensor* x_backprop_output, Tensor* scale_backprop_output,
                  Tensor* offset_backprop_output) {}
};
68 | |
69 | } // namespace functor |
70 | } // namespace tensorflow |
71 | |
72 | #endif // TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_ |
73 | |