1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/lite/c/builtin_op_data.h" |
16 | #include "tensorflow/lite/c/common.h" |
17 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
18 | #include "tensorflow/lite/kernels/internal/tensor.h" |
19 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
20 | #include "tensorflow/lite/kernels/internal/types.h" |
21 | #include "tensorflow/lite/kernels/kernel_util.h" |
22 | |
23 | namespace tflite { |
24 | namespace ops { |
25 | namespace builtin { |
26 | namespace fake_quant { |
27 | |
// This file contains the reference implementation of FakeQuant.
// Kernel variants this op can be instantiated with. Only the reference
// implementation exists; the enumerator is passed as the template argument
// of Eval<>.
enum KernelType { kReference };
32 | |
33 | struct OpContext { |
34 | OpContext(TfLiteContext* context, TfLiteNode* node) { |
35 | input = GetInput(context, node, 0); |
36 | output = GetOutput(context, node, 0); |
37 | } |
38 | const TfLiteTensor* input; |
39 | TfLiteTensor* output; |
40 | }; |
41 | |
42 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
43 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); |
44 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); |
45 | |
46 | const auto* params = |
47 | reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data); |
48 | |
49 | if (params->narrow_range) { |
50 | TF_LITE_KERNEL_LOG( |
51 | context, |
52 | "narrow_range FakeQuant is not currently supported at runtime. " |
53 | "narrow_range is only meant to be applied to weights, not activations" ); |
54 | return kTfLiteError; |
55 | } |
56 | |
57 | OpContext op_context(context, node); |
58 | TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims); |
59 | op_context.output->type = op_context.input->type; |
60 | return context->ResizeTensor(context, op_context.output, output_dims); |
61 | } |
62 | |
63 | template <KernelType kernel_type> |
64 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
65 | OpContext op_context(context, node); |
66 | |
67 | const auto* params = |
68 | reinterpret_cast<TfLiteFakeQuantParams*>(node->builtin_data); |
69 | |
70 | tflite::FakeQuantParams op_params; |
71 | op_params.num_bits = params->num_bits; |
72 | op_params.minmax.min = params->min; |
73 | op_params.minmax.max = params->max; |
74 | reference_ops::FakeQuant(op_params, GetTensorShape(op_context.input), |
75 | GetTensorData<float>(op_context.input), |
76 | GetTensorShape(op_context.output), |
77 | GetTensorData<float>(op_context.output)); |
78 | |
79 | return kTfLiteOk; |
80 | } |
81 | |
82 | } // namespace fake_quant |
83 | |
84 | TfLiteRegistration* Register_FAKE_QUANT_REF() { |
85 | static TfLiteRegistration r = {nullptr, nullptr, fake_quant::Prepare, |
86 | fake_quant::Eval<fake_quant::kReference>}; |
87 | return &r; |
88 | } |
89 | |
90 | TfLiteRegistration* Register_FAKE_QUANT() { return Register_FAKE_QUANT_REF(); } |
91 | |
92 | } // namespace builtin |
93 | } // namespace ops |
94 | } // namespace tflite |
95 | |