1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/kernels/cwise_op_clip.h" |
17 | |
18 | namespace tensorflow { |
19 | |
20 | typedef Eigen::ThreadPoolDevice CPUDevice; |
21 | typedef Eigen::GpuDevice GPUDevice; |
22 | |
// Basic coefficient-wise ternary operations.
// This is the case for example of the clip_by_value.
// Device: E.g., CPUDevice, GPUDevice.
// Functor: E.g., functor::UnaryClipOp, declared in cwise_op_clip.h.
27 | template <typename Device, typename T> |
28 | class ClipOp : public OpKernel { |
29 | public: |
30 | explicit ClipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
31 | |
32 | void Compute(OpKernelContext* ctx) override { |
33 | const Tensor& in0 = ctx->input(0); |
34 | const Tensor& in1 = ctx->input(1); |
35 | const Tensor& in2 = ctx->input(2); |
36 | OP_REQUIRES(ctx, (in0.shape() == in1.shape() || |
37 | TensorShapeUtils::IsScalar(in1.shape())) && |
38 | (in0.shape() == in2.shape() || |
39 | TensorShapeUtils::IsScalar(in2.shape())), |
40 | errors::InvalidArgument( |
41 | "clip_value_min and clip_value_max must be either of " |
42 | "the same shape as input, or a scalar. " , |
43 | "input shape: " , in0.shape().DebugString(), |
44 | "clip_value_min shape: " , in1.shape().DebugString(), |
45 | "clip_value_max shape: " , in2.shape().DebugString())); |
46 | |
47 | Tensor* out = nullptr; |
48 | OP_REQUIRES_OK( |
49 | ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out)); |
50 | if (out->NumElements() == 0) return; // Nothing to do for empty output |
51 | |
52 | auto in0_flat = in0.flat<T>(); |
53 | auto in1_flat = in1.flat<T>(); |
54 | auto in2_flat = in2.flat<T>(); |
55 | auto out_flat = out->flat<T>(); |
56 | const Device& d = ctx->eigen_device<Device>(); |
57 | |
58 | if (in1.shape() == in2.shape()) { |
59 | if (in0.shape() == in1.shape()) { |
60 | functor::TernaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat, |
61 | out_flat); |
62 | } else { |
63 | functor::UnaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat, |
64 | out_flat); |
65 | } |
66 | } else { |
67 | if (in0.shape() == in1.shape()) { |
68 | functor::BinaryLeftClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat, |
69 | out_flat); |
70 | } else { |
71 | functor::BinaryRightClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat, |
72 | out_flat); |
73 | } |
74 | } |
75 | } |
76 | }; |
77 | |
78 | namespace functor { |
// Unary functor for clip [Tensor, Scalar, Scalar]: clamps one element into
// [value_min, value_max], with both bounds captured at construction time.
template <typename T>
struct UnaryClipFunc {
  UnaryClipFunc(const T& value_min, const T& value_max)
      : value_min(value_min), value_max(value_max) {}
  const T operator()(const T& value) const {
    // Apply the upper bound first, then the lower bound; when the bounds
    // cross, the lower bound therefore wins.
    const T upper_bounded = std::min(value, value_max);
    return std::max(upper_bounded, value_min);
  }
  T value_min;
  T value_max;
};
90 | template <typename T> |
91 | struct UnaryClipOp<CPUDevice, T> { |
92 | void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat, |
93 | typename TTypes<T>::ConstFlat& in1_flat, |
94 | typename TTypes<T>::ConstFlat& in2_flat, |
95 | typename TTypes<T>::Flat& out_flat) const { |
96 | out_flat = in0_flat.unaryExpr(UnaryClipFunc<T>(in1_flat(0), in2_flat(0))); |
97 | } |
98 | }; |
99 | |
// Binary functor for clip [Tensor, Scalar, Tensor]: the lower bound is a
// captured scalar, the upper bound arrives per element.
template <typename T>
struct BinaryRightClipFunc {
  explicit BinaryRightClipFunc(const T& value_min) : value_min(value_min) {}
  const T operator()(const T& value, const T& value_max) const {
    // Upper bound first, lower bound second (lower bound wins on a tie
    // between crossed bounds).
    const T upper_bounded = std::min(value, value_max);
    return std::max(upper_bounded, value_min);
  }
  T value_min;
};
109 | template <typename T> |
110 | struct BinaryRightClipOp<CPUDevice, T> { |
111 | void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat, |
112 | typename TTypes<T>::ConstFlat& in1_flat, |
113 | typename TTypes<T>::ConstFlat& in2_flat, |
114 | typename TTypes<T>::Flat& out_flat) const { |
115 | out_flat = |
116 | in0_flat.binaryExpr(in2_flat, BinaryRightClipFunc<T>(in1_flat(0))); |
117 | } |
118 | }; |
119 | |
// Binary functor for clip [Tensor, Tensor, Scalar]: the upper bound is a
// captured scalar, the lower bound arrives per element.
template <typename T>
struct BinaryLeftClipFunc {
  explicit BinaryLeftClipFunc(const T& value_max) : value_max(value_max) {}
  const T operator()(const T& value, const T& value_min) const {
    // Upper bound first, lower bound second (lower bound wins on crossed
    // bounds).
    const T upper_bounded = std::min(value, value_max);
    return std::max(upper_bounded, value_min);
  }
  T value_max;
};
129 | template <typename T> |
130 | struct BinaryLeftClipOp<CPUDevice, T> { |
131 | void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat, |
132 | typename TTypes<T>::ConstFlat& in1_flat, |
133 | typename TTypes<T>::ConstFlat& in2_flat, |
134 | typename TTypes<T>::Flat& out_flat) const { |
135 | out_flat = |
136 | in0_flat.binaryExpr(in1_flat, BinaryLeftClipFunc<T>(in2_flat(0))); |
137 | } |
138 | }; |
139 | |
// Ternary functor for clip [Tensor, Tensor, Tensor]
template <typename T>
struct TernaryClipOp<CPUDevice, T> {
  // Elementwise clamp where both bounds are full tensors.  The upper bound
  // (in2) is applied first via cwiseMin, then the lower bound (in1) via
  // cwiseMax, so the lower bound wins if the bounds ever cross — matching
  // the scalar functors above.
  // NOTE(review): this is the only CPU specialization that assigns through
  // .device(d); the others evaluate the Eigen expression directly.
  void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat,
                  typename TTypes<T>::ConstFlat& in1_flat,
                  typename TTypes<T>::ConstFlat& in2_flat,
                  typename TTypes<T>::Flat& out_flat) const {
    out_flat.device(d) = in0_flat.cwiseMin(in2_flat).cwiseMax(in1_flat);
  }
};
150 | |
// Explicitly instantiate the CPU functor specializations for every dtype
// registered below, so their definitions are emitted in this translation
// unit.
#define INSTANTIATE_CPU(T) \
  template struct UnaryClipOp<CPUDevice, T>; \
  template struct BinaryRightClipOp<CPUDevice, T>; \
  template struct BinaryLeftClipOp<CPUDevice, T>; \
  template struct TernaryClipOp<CPUDevice, T>;
INSTANTIATE_CPU(Eigen::half);
INSTANTIATE_CPU(float);
INSTANTIATE_CPU(double);
INSTANTIATE_CPU(bfloat16);
INSTANTIATE_CPU(int8);
INSTANTIATE_CPU(int16);
INSTANTIATE_CPU(int32);
INSTANTIATE_CPU(int64_t);
INSTANTIATE_CPU(uint8);
INSTANTIATE_CPU(uint16);
#undef INSTANTIATE_CPU
167 | } // namespace functor |
168 | |
// Register the ClipByValue op on CPU for each supported dtype.
#define REGISTER_CPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("ClipByValue").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ClipOp<CPUDevice, type>);

REGISTER_CPU_KERNEL(Eigen::half);
REGISTER_CPU_KERNEL(float);
REGISTER_CPU_KERNEL(double);
REGISTER_CPU_KERNEL(bfloat16);
REGISTER_CPU_KERNEL(int8);
REGISTER_CPU_KERNEL(int16);
REGISTER_CPU_KERNEL(int32);
REGISTER_CPU_KERNEL(int64_t);
REGISTER_CPU_KERNEL(uint8);
REGISTER_CPU_KERNEL(uint16);
#undef REGISTER_CPU_KERNEL
185 | |
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Register the ClipByValue op on GPU.  The dtype list differs from the CPU
// one: int32 is handled by the host-memory registration below, and bfloat16
// is absent here — presumably because the GPU functors (instantiated in the
// .cu.cc file) do not cover it; TODO confirm before adding it.
#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("ClipByValue").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      ClipOp<GPUDevice, type>);
REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(double);
REGISTER_GPU_KERNEL(int8);
REGISTER_GPU_KERNEL(int16);
REGISTER_GPU_KERNEL(int64_t);
REGISTER_GPU_KERNEL(uint8);
REGISTER_GPU_KERNEL(uint16);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Note it therefore runs the CPUDevice implementation of ClipOp.
REGISTER_KERNEL_BUILDER(Name("ClipByValue")
                            .Device(DEVICE_GPU)
                            .HostMemory("t")
                            .HostMemory("clip_value_min")
                            .HostMemory("clip_value_max")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ClipOp<CPUDevice, int32>);

#undef REGISTER_GPU_KERNEL
#endif
215 | |
216 | } // namespace tensorflow |
217 | |