/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_KERNELS_DEQUANTIZE_H_
#define TENSORFLOW_LITE_KERNELS_DEQUANTIZE_H_

#include <stdint.h>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"

namespace tflite {
namespace ops {
namespace builtin {
namespace dequantize {
// This file has two implementations of Dequantize: reference and optimized.
enum KernelType {
  kReference,
  kGenericOptimized,
};

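// Returns true if `input` uses per-channel (per-axis) affine quantization,
// i.e. its TfLiteAffineQuantization params carry more than one scale value.
// Per-tensor quantization has exactly one scale/zero-point pair.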
inline bool IsQuantizedPerChannel(const TfLiteTensor* input) {
  if (input->quantization.type == kTfLiteAffineQuantization &&
      input->quantization.params) {
    auto* quant_params =
        reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
    return (quant_params->scale && quant_params->scale->size > 1);
  }
  return false;
}

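// Dequantizes a per-channel quantized `input` into float `output` using the
// reference kernels. Each element is mapped through
//   output = scale[c] * (input - zero_point[c]),
// where `c` indexes the slice along quantized_dimension. Only uint8 and int8
// inputs are supported here.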
inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
                                             TfLiteNode* node,
                                             const TfLiteTensor* input,
                                             TfLiteTensor* output) {
  const auto* quantization_params =
      reinterpret_cast<const TfLiteAffineQuantization*>(
          input->quantization.params);
  PerChannelDequantizationParams per_channel_op_params;
  per_channel_op_params.quantized_dimension =
      quantization_params->quantized_dimension;
  per_channel_op_params.scale = quantization_params->scale->data;
  per_channel_op_params.zero_point = quantization_params->zero_point->data;
  switch (input->type) {
    case kTfLiteUInt8:
      reference_ops::PerChannelDequantize<uint8_t>(
          per_channel_op_params, GetTensorShape(input),
          GetTensorData<uint8_t>(input), GetTensorShape(output),
          GetTensorData<float>(output));
      break;
    case kTfLiteInt8:
      reference_ops::PerChannelDequantize<int8_t>(
          per_channel_op_params, GetTensorShape(input),
          GetTensorData<int8_t>(input), GetTensorShape(output),
          GetTensorData<float>(output));
      break;
    default:
      TF_LITE_KERNEL_LOG(context, "Type %d not supported for per-channel.",
                         input->type);
      return kTfLiteError;
  }
  return kTfLiteOk;
}

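// Dequantizes `input` into float `output`. Per-channel quantized tensors are
// routed to PerChannelDequantizeImpl; otherwise the per-tensor affine mapping
//   output = scale * (input - zero_point)
// is applied, dispatching to reference or optimized kernels according to
// `kernel_type`. Float16 inputs are converted to float32 directly.
//
// Usage (a minimal sketch; fetching the tensors via GetInput/GetOutput is an
// assumption about the calling kernel, not something this header mandates):
//
//   const TfLiteTensor* input = GetInput(context, node, 0);
//   TfLiteTensor* output = GetOutput(context, node, 0);
//   TF_LITE_ENSURE_OK(
//       context, DequantizeImpl<kGenericOptimized>(context, node, input,
//                                                  output));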
template <KernelType kernel_type>
TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
                            const TfLiteTensor* input, TfLiteTensor* output) {
  if (IsQuantizedPerChannel(input)) {
    return PerChannelDequantizeImpl(context, node, input, output);
  }
  DequantizationParams op_params;
  op_params.zero_point = input->params.zero_point;
  op_params.scale = input->params.scale;
  switch (input->type) {
    case kTfLiteUInt8:
      if (kernel_type == kReference) {
        reference_ops::Dequantize(
            op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
            GetTensorShape(output), GetTensorData<float>(output));
      } else {
        optimized_ops::Dequantize(
            op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
            GetTensorShape(output), GetTensorData<float>(output));
      }
      break;
    case kTfLiteInt8:
      if (kernel_type == kReference) {
        reference_integer_ops::Dequantize<int8_t>(
            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
            GetTensorShape(output), GetTensorData<float>(output));
      } else {
        optimized_ops::Dequantize(
            op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
            GetTensorShape(output), GetTensorData<float>(output));
      }
      break;
    case kTfLiteInt16:
      if (kernel_type == kReference) {
        reference_integer_ops::Dequantize<int16_t>(
            op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
            GetTensorShape(output), GetTensorData<float>(output));
      } else {
        optimized_ops::Dequantize(
            op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
            GetTensorShape(output), GetTensorData<float>(output));
      }
      break;
    case kTfLiteFloat16: {
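      // TfLiteFloat16 stores the raw IEEE fp16 bits, matching Eigen::half's
      // representation, so the buffer can be reinterpreted in place.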
      const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
          GetTensorData<TfLiteFloat16>(input));
      reference_ops::Dequantize(GetTensorShape(input), half_data,
                                GetTensorShape(output),
                                GetTensorData<float>(output));
      break;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %d not supported.", input->type);
      return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace dequantize
}  // namespace builtin
}  // namespace ops
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_DEQUANTIZE_H_