/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements quantized versions of the Relu and Relu6 operations.
#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

template <typename T>
class QuantizedReluOp : public OpKernel {
 public:
  explicit QuantizedReluOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
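    // Inputs: the quantized tensor, plus two scalar floats giving the
    // real-valued range [min_input, max_input] that its quantized values
    // represent.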
    const Tensor& input = context->input(0);
    const Tensor& min_input_tensor = context->input(1);
    const Tensor& max_input_tensor = context->input(2);

    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(min_input_tensor.shape()),
        errors::InvalidArgument("`min_input` must be rank 0 but is rank ",
                                min_input_tensor.dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_input_tensor.shape()),
        errors::InvalidArgument("`max_input` must be rank 0 but is rank ",
                                max_input_tensor.dims()));

    const float min_input = min_input_tensor.scalar<float>()();
    const float max_input = max_input_tensor.scalar<float>()();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
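    // Relu clamps at a real-valued floor of 0.0f; FloatToQuantized (from
    // quantization_utils.h) maps that floor into the input's quantized range
    // so the comparison can run directly on quantized values.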
    const T min_as_quantized = FloatToQuantized<T>(0.0f, min_input, max_input);

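    // Fast path: when the platform-optimized meta kernel is available and the
    // input is quint8, clamp with meta::Clamp (255 is the quint8 saturation
    // value, so it is a no-op upper bound). Otherwise fall back to an Eigen
    // expression evaluated on the CPU device.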
    if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
      auto input_ui8_array = input.flat<quint8>();
      meta::Clamp(context, input_ui8_array.data(), input_ui8_array.size(),
                  min_as_quantized, 255, output->flat<quint8>().data());
    } else {
      output->flat<T>().device(context->eigen_cpu_device()) =
          input.flat<T>().cwiseMax(min_as_quantized).template cast<T>();
    }

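    // Relu leaves the quantization range unchanged, so the input min/max are
    // passed through as the output range.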
    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    output_min->flat<float>()(0) = min_input;
    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    output_max->flat<float>()(0) = max_input;
  }
};

template <typename T>
class QuantizedRelu6Op : public OpKernel {
 public:
  explicit QuantizedRelu6Op(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& min_input_tensor = context->input(1);
    const Tensor& max_input_tensor = context->input(2);

    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(min_input_tensor.shape()),
        errors::InvalidArgument("`min_input` must be rank 0 but is rank ",
                                min_input_tensor.dims()));
    OP_REQUIRES(
        context, TensorShapeUtils::IsScalar(max_input_tensor.shape()),
        errors::InvalidArgument("`max_input` must be rank 0 but is rank ",
                                max_input_tensor.dims()));

    const float min_input = min_input_tensor.scalar<float>()();
    const float max_input = max_input_tensor.scalar<float>()();

    Tensor* output = nullptr;
    OP_REQUIRES_OK(context,
                   context->allocate_output(0, input.shape(), &output));
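    // Relu6 clamps to the real-valued interval [0.0f, 6.0f]; translate both
    // endpoints into the input's quantized range.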
    const T min_as_quantized = FloatToQuantized<T>(0.0f, min_input, max_input);
    const T max_as_quantized = FloatToQuantized<T>(6.0f, min_input, max_input);

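    // As in QuantizedReluOp, prefer the optimized meta kernel for quint8 and
    // fall back to an Eigen cwiseMax/cwiseMin expression otherwise.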
    if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
      auto input_ui8_array = input.flat<quint8>();
      meta::Clamp(context, input_ui8_array.data(), input_ui8_array.size(),
                  min_as_quantized, max_as_quantized,
                  output->flat<quint8>().data());
    } else {
      output->flat<T>().device(context->eigen_cpu_device()) =
          input.flat<T>()
              .cwiseMax(min_as_quantized)
              .cwiseMin(max_as_quantized)
              .template cast<T>();
    }

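    // The input quantization range is passed through unchanged.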
    Tensor* output_min = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(1, {}, &output_min));
    output_min->flat<float>()(0) = min_input;
    Tensor* output_max = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(2, {}, &output_max));
    output_max->flat<float>()(0) = max_input;
  }
};

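// Register CPU kernels for both quantized types these ops support: 32-bit
// (qint32) and 8-bit (quint8) tensors. The quint8 variants are the ones that
// can take the meta::Clamp fast path above.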
REGISTER_KERNEL_BUILDER(Name("QuantizedRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("Tinput")
                            .TypeConstraint<qint32>("out_type"),
                        QuantizedReluOp<qint32>);
REGISTER_KERNEL_BUILDER(Name("QuantizedRelu")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<quint8>("out_type"),
                        QuantizedReluOp<quint8>);

REGISTER_KERNEL_BUILDER(Name("QuantizedRelu6")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("Tinput")
                            .TypeConstraint<qint32>("out_type"),
                        QuantizedRelu6Op<qint32>);
REGISTER_KERNEL_BUILDER(Name("QuantizedRelu6")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("Tinput")
                            .TypeConstraint<quint8>("out_type"),
                        QuantizedRelu6Op<quint8>);

}  // namespace tensorflow