1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/array_ops.cc. |
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
#include <math.h>

#include <algorithm>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/errors.h"
29 | |
30 | namespace tensorflow { |
31 | |
32 | typedef Eigen::ThreadPoolDevice CPUDevice; |
33 | |
34 | void CalculateUsedRange(const Tensor& input, qint32* used_min_quantized, |
35 | qint32* used_max_quantized) { |
36 | auto input_array = input.flat<qint32>(); |
37 | Eigen::Tensor<qint32, 0, Eigen::RowMajor> min = input_array.minimum(); |
38 | Eigen::Tensor<qint32, 0, Eigen::RowMajor> max = input_array.maximum(); |
39 | *used_min_quantized = min(); |
40 | *used_max_quantized = max(); |
41 | } |
42 | |
43 | class RequantizationRangeOp : public OpKernel { |
44 | public: |
45 | explicit RequantizationRangeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} |
46 | |
47 | void Compute(OpKernelContext* ctx) override { |
48 | const Tensor& input = ctx->input(0); |
49 | OP_REQUIRES(ctx, ctx->input(1).NumElements() > 0, |
50 | errors::InvalidArgument("Input min must not be empty." )); |
51 | OP_REQUIRES(ctx, ctx->input(2).NumElements() > 0, |
52 | errors::InvalidArgument("Input max must not be empty." )); |
53 | const float input_min_float = ctx->input(1).flat<float>()(0); |
54 | const float input_max_float = ctx->input(2).flat<float>()(0); |
55 | Tensor* output_min = nullptr; |
56 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output_min)); |
57 | Tensor* output_max = nullptr; |
58 | OP_REQUIRES_OK(ctx, ctx->allocate_output(1, TensorShape({}), &output_max)); |
59 | |
60 | qint32 used_min_quantized; |
61 | qint32 used_max_quantized; |
62 | CalculateUsedRange(input, &used_min_quantized, &used_max_quantized); |
63 | |
64 | // We want to make sure that the minimum is no larger than zero, so that the |
65 | // convolution operation can run efficiently. |
66 | const float used_min_float = std::min( |
67 | 0.0f, |
68 | QuantizedToFloat(used_min_quantized, input_min_float, input_max_float)); |
69 | const float used_max_float = |
70 | QuantizedToFloat(used_max_quantized, input_min_float, input_max_float); |
71 | |
72 | output_min->flat<float>().setConstant(used_min_float); |
73 | output_max->flat<float>().setConstant(used_max_float); |
74 | } |
75 | }; |
76 | |
// Register the CPU kernel; the op is constrained to qint32 input data.
REGISTER_KERNEL_BUILDER(Name("RequantizationRange" )
                            .Device(DEVICE_CPU)
                            .TypeConstraint<qint32>("Tinput" ),
                        RequantizationRangeOp);
81 | |
82 | } // namespace tensorflow |
83 | |