1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/kernels/cwise_op_clip.h"
17
18namespace tensorflow {
19
20typedef Eigen::ThreadPoolDevice CPUDevice;
21typedef Eigen::GpuDevice GPUDevice;
22
23// Basic coefficient-wise tenary operations.
24// This is the case for example of the clip_by_value.
25// Device: E.g., CPUDevice, GPUDevice.
26// Functor: defined above. E.g., functor::clip.
27template <typename Device, typename T>
28class ClipOp : public OpKernel {
29 public:
30 explicit ClipOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
31
32 void Compute(OpKernelContext* ctx) override {
33 const Tensor& in0 = ctx->input(0);
34 const Tensor& in1 = ctx->input(1);
35 const Tensor& in2 = ctx->input(2);
36 OP_REQUIRES(ctx, (in0.shape() == in1.shape() ||
37 TensorShapeUtils::IsScalar(in1.shape())) &&
38 (in0.shape() == in2.shape() ||
39 TensorShapeUtils::IsScalar(in2.shape())),
40 errors::InvalidArgument(
41 "clip_value_min and clip_value_max must be either of "
42 "the same shape as input, or a scalar. ",
43 "input shape: ", in0.shape().DebugString(),
44 "clip_value_min shape: ", in1.shape().DebugString(),
45 "clip_value_max shape: ", in2.shape().DebugString()));
46
47 Tensor* out = nullptr;
48 OP_REQUIRES_OK(
49 ctx, ctx->forward_input_or_allocate_output({0}, 0, in0.shape(), &out));
50 if (out->NumElements() == 0) return; // Nothing to do for empty output
51
52 auto in0_flat = in0.flat<T>();
53 auto in1_flat = in1.flat<T>();
54 auto in2_flat = in2.flat<T>();
55 auto out_flat = out->flat<T>();
56 const Device& d = ctx->eigen_device<Device>();
57
58 if (in1.shape() == in2.shape()) {
59 if (in0.shape() == in1.shape()) {
60 functor::TernaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
61 out_flat);
62 } else {
63 functor::UnaryClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
64 out_flat);
65 }
66 } else {
67 if (in0.shape() == in1.shape()) {
68 functor::BinaryLeftClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
69 out_flat);
70 } else {
71 functor::BinaryRightClipOp<Device, T>()(d, in0_flat, in1_flat, in2_flat,
72 out_flat);
73 }
74 }
75 }
76};
77
78namespace functor {
// Unary functor for clip [Tensor, Scalar, Scalar].
// Stores both scalar bounds and clamps a single element per call. The upper
// bound is applied first and the lower bound second, so the lower bound wins
// when the bounds are inverted.
template <typename T>
struct UnaryClipFunc {
  UnaryClipFunc(const T& lower, const T& upper)
      : value_min(lower), value_max(upper) {}
  const T operator()(const T& value) const {
    const T& capped_above = std::min(value, value_max);
    return std::max(capped_above, value_min);
  }
  T value_min;
  T value_max;
};
90template <typename T>
91struct UnaryClipOp<CPUDevice, T> {
92 void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat,
93 typename TTypes<T>::ConstFlat& in1_flat,
94 typename TTypes<T>::ConstFlat& in2_flat,
95 typename TTypes<T>::Flat& out_flat) const {
96 out_flat = in0_flat.unaryExpr(UnaryClipFunc<T>(in1_flat(0), in2_flat(0)));
97 }
98};
99
// Binary functor for clip [Tensor, Scalar, Tensor].
// Stores the scalar lower bound; the upper bound arrives per element as the
// second call argument. The upper bound is applied first, then the lower.
template <typename T>
struct BinaryRightClipFunc {
  explicit BinaryRightClipFunc(const T& lower) : value_min(lower) {}
  const T operator()(const T& value, const T& upper) const {
    const T& capped_above = std::min(value, upper);
    return std::max(capped_above, value_min);
  }
  T value_min;
};
109template <typename T>
110struct BinaryRightClipOp<CPUDevice, T> {
111 void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat,
112 typename TTypes<T>::ConstFlat& in1_flat,
113 typename TTypes<T>::ConstFlat& in2_flat,
114 typename TTypes<T>::Flat& out_flat) const {
115 out_flat =
116 in0_flat.binaryExpr(in2_flat, BinaryRightClipFunc<T>(in1_flat(0)));
117 }
118};
119
// Binary functor for clip [Tensor, Tensor, Scalar].
// Stores the scalar upper bound; the lower bound arrives per element as the
// second call argument. The upper bound is applied first, then the lower.
template <typename T>
struct BinaryLeftClipFunc {
  explicit BinaryLeftClipFunc(const T& upper) : value_max(upper) {}
  const T operator()(const T& value, const T& lower) const {
    const T& capped_above = std::min(value, value_max);
    return std::max(capped_above, lower);
  }
  T value_max;
};
129template <typename T>
130struct BinaryLeftClipOp<CPUDevice, T> {
131 void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat,
132 typename TTypes<T>::ConstFlat& in1_flat,
133 typename TTypes<T>::ConstFlat& in2_flat,
134 typename TTypes<T>::Flat& out_flat) const {
135 out_flat =
136 in0_flat.binaryExpr(in1_flat, BinaryLeftClipFunc<T>(in2_flat(0)));
137 }
138};
139
140// Ternary functor for clip [Tensor, Tensor, Tensor]
141template <typename T>
142struct TernaryClipOp<CPUDevice, T> {
143 void operator()(const CPUDevice& d, typename TTypes<T>::ConstFlat& in0_flat,
144 typename TTypes<T>::ConstFlat& in1_flat,
145 typename TTypes<T>::ConstFlat& in2_flat,
146 typename TTypes<T>::Flat& out_flat) const {
147 out_flat.device(d) = in0_flat.cwiseMin(in2_flat).cwiseMax(in1_flat);
148 }
149};
150
// Explicitly instantiates all four CPU clip functor specializations for the
// element type T.
#define INSTANTIATE_CPU(T)                         \
  template struct UnaryClipOp<CPUDevice, T>;       \
  template struct BinaryRightClipOp<CPUDevice, T>; \
  template struct BinaryLeftClipOp<CPUDevice, T>;  \
  template struct TernaryClipOp<CPUDevice, T>;
// One instantiation per element type; this list is the same set of types as
// the CPU "ClipByValue" kernel registrations in this file.
INSTANTIATE_CPU(Eigen::half);
INSTANTIATE_CPU(float);
INSTANTIATE_CPU(double);
INSTANTIATE_CPU(bfloat16);
INSTANTIATE_CPU(int8);
INSTANTIATE_CPU(int16);
INSTANTIATE_CPU(int32);
INSTANTIATE_CPU(int64_t);
INSTANTIATE_CPU(uint8);
INSTANTIATE_CPU(uint16);
// The helper macro is local to this file.
#undef INSTANTIATE_CPU
167} // namespace functor
168
// Registers ClipOp<CPUDevice, type> as the "ClipByValue" kernel on DEVICE_CPU
// for the case where the op's "T" attribute equals `type`.
#define REGISTER_CPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("ClipByValue").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
      ClipOp<CPUDevice, type>);

// One registration per supported CPU element type.
REGISTER_CPU_KERNEL(Eigen::half);
REGISTER_CPU_KERNEL(float);
REGISTER_CPU_KERNEL(double);
REGISTER_CPU_KERNEL(bfloat16);
REGISTER_CPU_KERNEL(int8);
REGISTER_CPU_KERNEL(int16);
REGISTER_CPU_KERNEL(int32);
REGISTER_CPU_KERNEL(int64_t);
REGISTER_CPU_KERNEL(uint8);
REGISTER_CPU_KERNEL(uint16);
// The helper macro is local to this file.
#undef REGISTER_CPU_KERNEL
185
// GPU registrations are compiled only for CUDA or ROCm builds.
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

// Registers ClipOp<GPUDevice, type> as the "ClipByValue" kernel on DEVICE_GPU
// for the case where the op's "T" attribute equals `type`.
// NOTE(review): the GPUDevice functor specializations are not defined in this
// file; presumably they live in a separate GPU source file -- confirm.
#define REGISTER_GPU_KERNEL(type)                                       \
  REGISTER_KERNEL_BUILDER(                                              \
      Name("ClipByValue").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
      ClipOp<GPUDevice, type>);
// Unlike the CPU list, bfloat16 is not registered here, and int32 is handled
// by the separate host-memory registration below.
REGISTER_GPU_KERNEL(Eigen::half);
REGISTER_GPU_KERNEL(float);
REGISTER_GPU_KERNEL(double);
REGISTER_GPU_KERNEL(int8);
REGISTER_GPU_KERNEL(int16);
REGISTER_GPU_KERNEL(int64_t);
REGISTER_GPU_KERNEL(uint8);
REGISTER_GPU_KERNEL(uint16);

// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
// registration requires all int32 inputs and outputs to be in host memory.
// Note that the kernel class used is the CPUDevice instantiation of ClipOp.
REGISTER_KERNEL_BUILDER(Name("ClipByValue")
                            .Device(DEVICE_GPU)
                            .HostMemory("t")
                            .HostMemory("clip_value_min")
                            .HostMemory("clip_value_max")
                            .HostMemory("output")
                            .TypeConstraint<int32>("T"),
                        ClipOp<CPUDevice, int32>);

// The helper macro is local to this file.
#undef REGISTER_GPU_KERNEL
#endif
215
216} // namespace tensorflow
217