/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/array_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/meta_support.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/bfloat16.h"

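// Dequantization modes supported by this kernel:
//  * MIN_COMBINED: maps the full code range of T linearly onto
//    [min_range, max_range], shifting signed types by half the range.
//  * MIN_FIRST: like MIN_COMBINED, but first rounds the range minimum to a
//    multiple of the scale (see quantization_utils.h); whole-tensor only.
//  * SCALED: multiplies by a single scale factor derived from the range
//    endpoints; no shift is applied.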
namespace {
enum {
  QUANTIZE_MODE_MIN_COMBINED,
  QUANTIZE_MODE_MIN_FIRST,
  QUANTIZE_MODE_SCALED,
};
}  // namespace

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;

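// Casts the dequantized float result to the requested output type. The
// primary template is an identity (used for float outputs); bfloat16 gets
// an explicit narrowing conversion.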
template <typename T>
T Cast(float v) {
  return v;
}

template <>
bfloat16 Cast<bfloat16>(float v) {
  return bfloat16(v);
}

template <typename Device, typename T, typename S>
class DequantizeOp : public OpKernel {
 public:
  explicit DequantizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    string mode_string;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("mode", &mode_string));
    OP_REQUIRES(
        ctx,
        (ctx->output_type(0) == DT_FLOAT || ctx->output_type(0) == DT_BFLOAT16),
        errors::InvalidArgument("Output type must be bfloat16 or float,"
                                " is '" +
                                DataTypeString(ctx->output_type(0)) + "'"));

    need_cast_ = true;
    if (ctx->output_type(0) == DT_FLOAT) {
      need_cast_ = false;
      OP_REQUIRES(ctx,
                  (mode_string == "MIN_COMBINED" ||
                   mode_string == "MIN_FIRST" || mode_string == "SCALED"),
                  errors::InvalidArgument("Mode string must be 'MIN_COMBINED',"
                                          " 'MIN_FIRST', or 'SCALED', is '" +
                                          mode_string + "'"));
    } else {
      OP_REQUIRES(
          ctx, (mode_string == "MIN_COMBINED"),
          errors::InvalidArgument("When output type is bfloat16, Mode"
                                  " string must be 'MIN_COMBINED', is '" +
                                  mode_string + "'"));
    }

    if (mode_string == "MIN_COMBINED") {
      mode_ = QUANTIZE_MODE_MIN_COMBINED;
    } else if (mode_string == "MIN_FIRST") {
      mode_ = QUANTIZE_MODE_MIN_FIRST;
    } else if (mode_string == "SCALED") {
      mode_ = QUANTIZE_MODE_SCALED;
    }
    OP_REQUIRES_OK(ctx, ctx->GetAttr("narrow_range", &narrow_range_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("axis", &axis_));
  }

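  // Validates the range tensors, dequantizes either the whole tensor
  // (axis == -1) or one slice per entry along `axis`, and finally casts
  // the float results to S when the output type is bfloat16.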
  void Compute(OpKernelContext* ctx) override {
    const Tensor& input = ctx->input(0);
    const Tensor& input_min_tensor = ctx->input(1);
    const Tensor& input_max_tensor = ctx->input(2);

    OP_REQUIRES(
        ctx, axis_ >= -1,
        errors::InvalidArgument("Axis must be at least -1, got ", axis_));
    OP_REQUIRES(
        ctx, axis_ < input.dims(),
        errors::InvalidArgument("Axis must be less than input dimension(",
                                input.dims(), "), got ", axis_));

    int num_slices = 1;
    if (axis_ > -1) {
      num_slices = input.dim_size(axis_);
    }
    OP_REQUIRES(ctx, input_min_tensor.NumElements() == num_slices,
                errors::InvalidArgument(
                    "input_min_tensor must have as many elements as input on "
                    "the dequantization axis (",
                    axis_, "), got ", input_min_tensor.NumElements(),
                    ", expected ", num_slices));
    OP_REQUIRES(ctx, input_max_tensor.NumElements() == num_slices,
                errors::InvalidArgument(
                    "input_max_tensor must have as many elements as input on "
                    "the dequantization axis (",
                    axis_, "), got ", input_max_tensor.NumElements(),
                    ", expected ", num_slices));

    Tensor* output = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, input.shape(), &output));
    Tensor float_output =
        need_cast_ ? tensorflow::Tensor(DT_FLOAT, input.shape()) : *output;
    if (num_slices == 1) {
      const float min_range = input_min_tensor.flat<float>()(0);
      const float max_range = input_max_tensor.flat<float>()(0);
      DequantizeTensor(ctx, input, min_range, max_range, &float_output);
    } else {
      OP_REQUIRES(ctx, mode_ != QUANTIZE_MODE_MIN_FIRST,
                  errors::Unimplemented("MIN_FIRST mode is not implemented for "
                                        "Dequantize with axis != -1."));

      int64_t pre_dim = 1, post_dim = 1;
      for (int i = 0; i < axis_; ++i) {
        pre_dim *= float_output.dim_size(i);
      }
      for (int i = axis_ + 1; i < float_output.dims(); ++i) {
        post_dim *= float_output.dim_size(i);
      }
      auto input_tensor = input.template bit_casted_shaped<T, 3>(
          {pre_dim, num_slices, post_dim});
      auto output_tensor =
          float_output.flat_inner_outer_dims<float, 3>(axis_ - 1);
      auto min_ranges = input_min_tensor.vec<float>();
      auto max_ranges = input_max_tensor.vec<float>();
      for (int i = 0; i < num_slices; ++i) {
        DequantizeSlice(ctx->eigen_device<Device>(), ctx,
                        input_tensor.template chip<1>(i), min_ranges(i),
                        max_ranges(i), output_tensor.template chip<1>(i));
      }
    }
    if (need_cast_) {
      S* out_ptr = output->flat<S>().data();
      float* in_ptr = float_output.flat<float>().data();
      for (int64_t i = 0; i < float_output.NumElements(); ++i) {
        out_ptr[i] = static_cast<S>(in_ptr[i]);
      }
    }
  }

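  // Dequantizes all of `input` with a single [min_range, max_range] pair.
  // For signed T, `half_range` recenters the quantized codes so that
  // std::numeric_limits<T>::min() maps to min_range in MIN_COMBINED mode.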
  void DequantizeTensor(OpKernelContext* ctx, const Tensor& input,
                        const float min_range, const float max_range,
                        Tensor* output) {
    const float half_range =
        !std::is_signed<T>::value
            ? 0.0f
            : (static_cast<float>(std::numeric_limits<T>::max()) -
               std::numeric_limits<T>::min() + 1) /
                  2.0f;

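    // MIN_COMBINED: linear map of the full code range of T onto
    // [min_range, max_range]. Example with T = quint8, min_range = -1.0f,
    // max_range = 1.0f: scale_factor = 2.0f / 255, half_range = 0, so code
    // 0 dequantizes to -1.0f and code 255 to 1.0f.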
    if (mode_ == QUANTIZE_MODE_MIN_COMBINED) {
      const float scale_factor =
          (max_range - min_range) /
          (static_cast<float>(std::numeric_limits<T>::max()) -
           std::numeric_limits<T>::min());

      const auto& input_tensor = input.flat<T>();
      output->flat<float>() =
          ((input_tensor.template cast<float>() + half_range) * scale_factor) +
          min_range;

    } else if (mode_ == QUANTIZE_MODE_MIN_FIRST) {
      if (meta::IsSupportedAndEnabled() && std::is_same<T, quint8>()) {
        auto input_ui8_array = input.flat<quint8>();
        meta::Dequantize(ctx, input_ui8_array.data(), input_ui8_array.size(),
                         min_range, max_range, output->flat<float>().data());
      } else {
        QuantizedTensorToFloatInPlaceUsingEigen<T>(
            ctx->template eigen_device<Device>(), input, min_range, max_range,
            output);
      }
    } else if (mode_ == QUANTIZE_MODE_SCALED) {
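      // SCALED: a single multiplicative factor, chosen so that the extreme
      // representable code maps onto the wider of the two range endpoints;
      // with narrow_range the most negative code is excluded. Example with
      // T = qint8, narrow_range = false, min_range = -2.0f, max_range =
      // 1.0f: scale_factor = max(-2.0f / -128, 1.0f / 127) = 1.0f / 64, so
      // code -128 dequantizes to -2.0f and code 127 to roughly 1.98f.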
      const int min_output_value =
          std::numeric_limits<T>::min() + (narrow_range_ ? 1 : 0);
      const float scale_factor =
          std::numeric_limits<T>::min() == 0
              ? (max_range / std::numeric_limits<T>::max())
              : std::max(min_range / min_output_value,
                         max_range / std::numeric_limits<T>::max());
      const auto& input_tensor = input.flat<T>();
      output->flat<float>() =
          input_tensor.template cast<int>().template cast<float>() *
          scale_factor;
    }
  }

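  // Dequantizes one slice along `axis` on device `d`, using that slice's
  // [min_range, max_range]. Only MIN_COMBINED and SCALED reach this point;
  // MIN_FIRST with axis != -1 is rejected in Compute().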
  template <typename ConstVec, typename Vec>
  void DequantizeSlice(const Device& d, OpKernelContext* ctx,
                       const ConstVec& input, float min_range, float max_range,
                       Vec output) {
    // TODO(pauldonnelly): Factor out the similar calculations in quantize,
    // dequantize and quantize_and_dequantize ops.
    const float half_range =
        !std::is_signed<T>::value
            ? 0.0f
            : (static_cast<float>(std::numeric_limits<T>::max()) -
               std::numeric_limits<T>::min() + 1) /
                  2.0f;

    if (mode_ == QUANTIZE_MODE_MIN_COMBINED) {
      const float scale_factor =
          (max_range - min_range) /
          (static_cast<float>(std::numeric_limits<T>::max()) -
           std::numeric_limits<T>::min());

      output.device(d) =
          ((input.template cast<float>() + half_range) * scale_factor) +
          min_range;
    } else if (mode_ == QUANTIZE_MODE_SCALED) {
      const int min_output_value =
          std::numeric_limits<T>::min() + (narrow_range_ ? 1 : 0);
      const float scale_factor =
          std::numeric_limits<T>::min() == 0
              ? (max_range / std::numeric_limits<T>::max())
              : std::max(min_range / min_output_value,
                         max_range / std::numeric_limits<T>::max());
      output.device(d) = input.template cast<float>() * scale_factor;
    }
  }

 private:
  int mode_;
  int axis_;
  bool narrow_range_;
  bool need_cast_;
};

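// CPU kernel registrations: one per quantized input type T, each available
// with either a float or a bfloat16 output dtype.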
REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T")
                            .TypeConstraint<float>("dtype"),
                        DequantizeOp<CPUDevice, quint8, float>);
REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("T")
                            .TypeConstraint<float>("dtype"),
                        DequantizeOp<CPUDevice, qint8, float>);
REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint16>("T")
                            .TypeConstraint<float>("dtype"),
                        DequantizeOp<CPUDevice, quint16, float>);
REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint16>("T")
                            .TypeConstraint<float>("dtype"),
                        DequantizeOp<CPUDevice, qint16, float>);
REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("T")
                            .TypeConstraint<float>("dtype"),
                        DequantizeOp<CPUDevice, qint32, float>);

REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint8>("T")
                            .TypeConstraint<bfloat16>("dtype"),
                        DequantizeOp<CPUDevice, quint8, bfloat16>);
REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint8>("T")
                            .TypeConstraint<bfloat16>("dtype"),
                        DequantizeOp<CPUDevice, qint8, bfloat16>);
REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<quint16>("T")
                            .TypeConstraint<bfloat16>("dtype"),
                        DequantizeOp<CPUDevice, quint16, bfloat16>);
REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint16>("T")
                            .TypeConstraint<bfloat16>("dtype"),
                        DequantizeOp<CPUDevice, qint16, bfloat16>);
REGISTER_KERNEL_BUILDER(Name("Dequantize")
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("T")
                            .TypeConstraint<bfloat16>("dtype"),
                        DequantizeOp<CPUDevice, qint32, bfloat16>);

}  // namespace tensorflow