1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/kernels/dequantize.h"
16
17#include <stddef.h>
18
19#include "tensorflow/lite/c/common.h"
20#include "tensorflow/lite/kernels/internal/optimized/neon_check.h"
21#include "tensorflow/lite/kernels/kernel_util.h"
22
23namespace tflite {
24namespace ops {
25namespace builtin {
26namespace dequantize {
27
28struct OpContext {
29 OpContext(TfLiteContext* context, TfLiteNode* node) {
30 input = GetInput(context, node, 0);
31 output = GetOutput(context, node, 0);
32 }
33 const TfLiteTensor* input;
34 TfLiteTensor* output;
35};
36
// Per-node state allocated in Init and released in Free.
struct OpData {
  // True once a constant input has been dequantized into the persistent
  // output buffer; only meaningful when the input tensor is constant.
  // In-class initializer guarantees the flag is never read indeterminate,
  // regardless of how the struct is constructed.
  bool float_dequantized_weights_initialized = false;
};
41
42void* Init(TfLiteContext* context, const char* buffer, size_t length) {
43 auto* op_data = new OpData();
44 op_data->float_dequantized_weights_initialized = false;
45 return op_data;
46}
47
48void Free(TfLiteContext* context, void* buffer) {
49 delete reinterpret_cast<OpData*>(buffer);
50}
51
52TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
53 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
54 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
55
56 OpContext op_context(context, node);
57
58 TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 ||
59 op_context.input->type == kTfLiteInt8 ||
60 op_context.input->type == kTfLiteInt16 ||
61 op_context.input->type == kTfLiteFloat16);
62
63 if (op_context.input->type == kTfLiteInt16) {
64 TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, 0);
65 }
66
67 op_context.output->type = kTfLiteFloat32;
68 // If the input tensor is constant, we can persist the dequantized value in
69 // the output tensor. Otherwise we run dequantize upon each eval.
70 if (IsConstantTensor(op_context.input)) {
71 op_context.output->allocation_type = kTfLiteArenaRwPersistent;
72 }
73 return context->ResizeTensor(context, op_context.output,
74 TfLiteIntArrayCopy(op_context.input->dims));
75}
76
77template <KernelType kernel_type>
78TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
79 OpData* op_data = reinterpret_cast<OpData*>(node->user_data);
80 OpContext op_context(context, node);
81 if (IsConstantTensor(op_context.input) &&
82 op_data->float_dequantized_weights_initialized) {
83 return kTfLiteOk;
84 }
85
86 auto status = DequantizeImpl<kernel_type>(context, node, op_context.input,
87 op_context.output);
88 if (status != kTfLiteOk) {
89 return status;
90 }
91
92 if (IsConstantTensor(op_context.input)) {
93 op_data->float_dequantized_weights_initialized = true;
94 }
95 return kTfLiteOk;
96}
97
98} // namespace dequantize
99
100TfLiteRegistration* Register_DEQUANTIZE_OPT() {
101 static TfLiteRegistration r = {
102 dequantize::Init, dequantize::Free, dequantize::Prepare,
103 dequantize::Eval<dequantize::kGenericOptimized>};
104 return &r;
105}
106
107TfLiteRegistration* Register_DEQUANTIZE_REF() {
108 static TfLiteRegistration r = {dequantize::Init, dequantize::Free,
109 dequantize::Prepare,
110 dequantize::Eval<dequantize::kReference>};
111 return &r;
112}
113
// Default registration: prefer the optimized kernel when NEON support is
// compiled in, otherwise fall back to the reference implementation.
TfLiteRegistration* Register_DEQUANTIZE() {
#ifdef USE_NEON
  return Register_DEQUANTIZE_OPT();
#else
  return Register_DEQUANTIZE_REF();
#endif
}
121
122} // namespace builtin
123} // namespace ops
124} // namespace tflite
125