1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_LITE_KERNELS_DEQUANTIZE_H_
16#define TENSORFLOW_LITE_KERNELS_DEQUANTIZE_H_
17
18#include <stdint.h>
19
20#include "third_party/eigen3/Eigen/Core"
21#include "tensorflow/lite/c/common.h"
22#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
23#include "tensorflow/lite/kernels/internal/reference/dequantize.h"
24#include "tensorflow/lite/kernels/internal/reference/integer_ops/dequantize.h"
25#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
26#include "tensorflow/lite/kernels/internal/tensor.h"
27#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
28#include "tensorflow/lite/kernels/internal/types.h"
29
30namespace tflite {
31namespace ops {
32namespace builtin {
33namespace dequantize {
34
// This file provides two implementations of Dequantize; the KernelType
// template argument of DequantizeImpl selects between them. Only the
// per-tensor uint8/int8/int16 paths have an optimized variant; the float16
// and per-channel paths always use the reference kernels.
enum KernelType {
  kReference = 0,         // reference_ops / reference_integer_ops kernels.
  kGenericOptimized = 1,  // optimized_ops kernels.
};
40
41inline bool IsQuantizedPerChannel(const TfLiteTensor* input) {
42 if (input->quantization.type == kTfLiteAffineQuantization &&
43 input->quantization.params) {
44 auto* quant_params =
45 reinterpret_cast<TfLiteAffineQuantization*>(input->quantization.params);
46 return (quant_params->scale && quant_params->scale->size > 1);
47 }
48 return false;
49}
50
51inline TfLiteStatus PerChannelDequantizeImpl(TfLiteContext* context,
52 TfLiteNode* node,
53 const TfLiteTensor* input,
54 TfLiteTensor* output) {
55 const auto* quantization_params =
56 reinterpret_cast<const TfLiteAffineQuantization*>(
57 input->quantization.params);
58 PerChannelDequantizationParams per_channel_op_params;
59 per_channel_op_params.quantized_dimension =
60 quantization_params->quantized_dimension;
61 per_channel_op_params.scale = quantization_params->scale->data;
62 per_channel_op_params.zero_point = quantization_params->zero_point->data;
63 switch (input->type) {
64 case kTfLiteUInt8:
65 reference_ops::PerChannelDequantize<uint8_t>(
66 per_channel_op_params, GetTensorShape(input),
67 GetTensorData<uint8_t>(input), GetTensorShape(output),
68 GetTensorData<float>(output));
69 break;
70 case kTfLiteInt8:
71 reference_ops::PerChannelDequantize<int8_t>(
72 per_channel_op_params, GetTensorShape(input),
73 GetTensorData<int8_t>(input), GetTensorShape(output),
74 GetTensorData<float>(output));
75 break;
76 default:
77 TF_LITE_KERNEL_LOG(context, "Type %d not supported for per-channel.",
78 input->type);
79 return kTfLiteError;
80 }
81 return kTfLiteOk;
82}
83
84template <KernelType kernel_type>
85TfLiteStatus DequantizeImpl(TfLiteContext* context, TfLiteNode* node,
86 const TfLiteTensor* input, TfLiteTensor* output) {
87 if (IsQuantizedPerChannel(input)) {
88 return PerChannelDequantizeImpl(context, node, input, output);
89 }
90 DequantizationParams op_params;
91 op_params.zero_point = input->params.zero_point;
92 op_params.scale = input->params.scale;
93 switch (input->type) {
94 case kTfLiteUInt8:
95 if (kernel_type == kReference) {
96 reference_ops::Dequantize(
97 op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
98 GetTensorShape(output), GetTensorData<float>(output));
99 } else {
100 optimized_ops::Dequantize(
101 op_params, GetTensorShape(input), GetTensorData<uint8_t>(input),
102 GetTensorShape(output), GetTensorData<float>(output));
103 }
104 break;
105 case kTfLiteInt8:
106 if (kernel_type == kReference) {
107 reference_integer_ops::Dequantize<int8_t>(
108 op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
109 GetTensorShape(output), GetTensorData<float>(output));
110 } else {
111 optimized_ops::Dequantize(
112 op_params, GetTensorShape(input), GetTensorData<int8_t>(input),
113 GetTensorShape(output), GetTensorData<float>(output));
114 }
115 break;
116 case kTfLiteInt16:
117 if (kernel_type == kReference) {
118 reference_integer_ops::Dequantize<int16_t>(
119 op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
120 GetTensorShape(output), GetTensorData<float>(output));
121 } else {
122 optimized_ops::Dequantize(
123 op_params, GetTensorShape(input), GetTensorData<int16_t>(input),
124 GetTensorShape(output), GetTensorData<float>(output));
125 }
126 break;
127 case kTfLiteFloat16: {
128 const Eigen::half* half_data = reinterpret_cast<const Eigen::half*>(
129 GetTensorData<TfLiteFloat16>(input));
130 reference_ops::Dequantize(GetTensorShape(input), half_data,
131 GetTensorShape(output),
132 GetTensorData<float>(output));
133 break;
134 }
135 default:
136 TF_LITE_KERNEL_LOG(context, "Type %d not supported.", input->type);
137 return kTfLiteError;
138 }
139
140 return kTfLiteOk;
141}
142
143} // namespace dequantize
144} // namespace builtin
145} // namespace ops
146} // namespace tflite
147
148#endif // TENSORFLOW_LITE_KERNELS_DEQUANTIZE_H_
149