1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/kernels/dequantize.h" |
16 | |
17 | #include <stddef.h> |
18 | |
19 | #include "tensorflow/lite/c/common.h" |
20 | #include "tensorflow/lite/kernels/internal/optimized/neon_check.h" |
21 | #include "tensorflow/lite/kernels/kernel_util.h" |
22 | |
23 | namespace tflite { |
24 | namespace ops { |
25 | namespace builtin { |
26 | namespace dequantize { |
27 | |
28 | struct OpContext { |
29 | OpContext(TfLiteContext* context, TfLiteNode* node) { |
30 | input = GetInput(context, node, 0); |
31 | output = GetOutput(context, node, 0); |
32 | } |
33 | const TfLiteTensor* input; |
34 | TfLiteTensor* output; |
35 | }; |
36 | |
// Per-node state allocated by Init() and released by Free().
struct OpData {
  // True once a constant input has been dequantized into the output tensor,
  // letting later evaluations skip the work. Only consulted when the input
  // tensor is constant. Defaults to false so that every construction path,
  // not just value-initialization, starts with an empty cache.
  bool float_dequantized_weights_initialized = false;
};
41 | |
42 | void* Init(TfLiteContext* context, const char* buffer, size_t length) { |
43 | auto* op_data = new OpData(); |
44 | op_data->float_dequantized_weights_initialized = false; |
45 | return op_data; |
46 | } |
47 | |
48 | void Free(TfLiteContext* context, void* buffer) { |
49 | delete reinterpret_cast<OpData*>(buffer); |
50 | } |
51 | |
52 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
53 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
54 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
55 | |
56 | OpContext op_context(context, node); |
57 | |
58 | TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8 || |
59 | op_context.input->type == kTfLiteInt8 || |
60 | op_context.input->type == kTfLiteInt16 || |
61 | op_context.input->type == kTfLiteFloat16); |
62 | |
63 | if (op_context.input->type == kTfLiteInt16) { |
64 | TF_LITE_ENSURE_EQ(context, op_context.input->params.zero_point, 0); |
65 | } |
66 | |
67 | op_context.output->type = kTfLiteFloat32; |
68 | // If the input tensor is constant, we can persist the dequantized value in |
69 | // the output tensor. Otherwise we run dequantize upon each eval. |
70 | if (IsConstantTensor(op_context.input)) { |
71 | op_context.output->allocation_type = kTfLiteArenaRwPersistent; |
72 | } |
73 | return context->ResizeTensor(context, op_context.output, |
74 | TfLiteIntArrayCopy(op_context.input->dims)); |
75 | } |
76 | |
77 | template <KernelType kernel_type> |
78 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
79 | OpData* op_data = reinterpret_cast<OpData*>(node->user_data); |
80 | OpContext op_context(context, node); |
81 | if (IsConstantTensor(op_context.input) && |
82 | op_data->float_dequantized_weights_initialized) { |
83 | return kTfLiteOk; |
84 | } |
85 | |
86 | auto status = DequantizeImpl<kernel_type>(context, node, op_context.input, |
87 | op_context.output); |
88 | if (status != kTfLiteOk) { |
89 | return status; |
90 | } |
91 | |
92 | if (IsConstantTensor(op_context.input)) { |
93 | op_data->float_dequantized_weights_initialized = true; |
94 | } |
95 | return kTfLiteOk; |
96 | } |
97 | |
98 | } // namespace dequantize |
99 | |
100 | TfLiteRegistration* Register_DEQUANTIZE_OPT() { |
101 | static TfLiteRegistration r = { |
102 | dequantize::Init, dequantize::Free, dequantize::Prepare, |
103 | dequantize::Eval<dequantize::kGenericOptimized>}; |
104 | return &r; |
105 | } |
106 | |
107 | TfLiteRegistration* Register_DEQUANTIZE_REF() { |
108 | static TfLiteRegistration r = {dequantize::Init, dequantize::Free, |
109 | dequantize::Prepare, |
110 | dequantize::Eval<dequantize::kReference>}; |
111 | return &r; |
112 | } |
113 | |
114 | TfLiteRegistration* Register_DEQUANTIZE() { |
115 | #ifdef USE_NEON |
116 | return Register_DEQUANTIZE_OPT(); |
117 | #else |
118 | return Register_DEQUANTIZE_REF(); |
119 | #endif |
120 | } |
121 | |
122 | } // namespace builtin |
123 | } // namespace ops |
124 | } // namespace tflite |
125 | |