1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/lite/c/builtin_op_data.h"
16#include "tensorflow/lite/c/common.h"
17#include "tensorflow/lite/kernels/internal/compatibility.h"
18#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
19#include "tensorflow/lite/kernels/internal/reference/integer_ops/l2normalization.h"
20#include "tensorflow/lite/kernels/internal/reference/l2normalization.h"
21#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
22#include "tensorflow/lite/kernels/internal/tensor.h"
23#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
24#include "tensorflow/lite/kernels/internal/types.h"
25#include "tensorflow/lite/kernels/kernel_util.h"
26
27namespace tflite {
28namespace ops {
29namespace builtin {
30namespace l2norm {
31
// L2Norm is implemented twice in this file; the KernelType template argument
// of Eval picks which implementation serves the float and uint8 paths.
enum KernelType {
  kReference = 0,
  kGenericOptimized = 1,
};

// The op has exactly one input and one output tensor; these are their indices.
constexpr int kInputTensor = 0;
constexpr int kOutputTensor = 0;
40
41TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
42 auto* params = reinterpret_cast<TfLiteL2NormParams*>(node->builtin_data);
43
44 TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
45 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
46
47 const TfLiteTensor* input;
48 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
49 TfLiteTensor* output;
50 TF_LITE_ENSURE_OK(context,
51 GetOutputSafe(context, node, kOutputTensor, &output));
52
53 TF_LITE_ENSURE(context, NumDimensions(input) <= 4);
54
55 TF_LITE_ENSURE(context, output->type == kTfLiteFloat32 ||
56 output->type == kTfLiteUInt8 ||
57 output->type == kTfLiteInt8);
58 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
59
60 if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8) {
61 TF_LITE_ENSURE_EQ(context, output->params.scale, (1. / 128.));
62 if (output->type == kTfLiteUInt8) {
63 TF_LITE_ENSURE_EQ(context, output->params.zero_point, 128);
64 }
65 if (output->type == kTfLiteInt8) {
66 TF_LITE_ENSURE_EQ(context, output->params.zero_point, 0);
67 }
68 }
69
70 // TODO(ahentz): For some reason our implementations don't support
71 // activations.
72 TF_LITE_ENSURE_EQ(context, params->activation, kTfLiteActNone);
73
74 TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
75 return context->ResizeTensor(context, output, output_size);
76}
77
78template <KernelType kernel_type>
79TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
80 const TfLiteTensor* input;
81 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
82 TfLiteTensor* output;
83 TF_LITE_ENSURE_OK(context,
84 GetOutputSafe(context, node, kOutputTensor, &output));
85
86 // TODO(b/143912164): instead of hardcode the epsilon here, we should read it
87 // from tensorflow, i.e., adding a params.
88 // We don't compute epsilon for quantized kernel:
89 //
90 // epsilon_float = (epsilon_quant - zp) * scale
91 // so
92 // espsilon_quant = epsilon_float / scale + zp
93 // We know epsilon_float is just a very small number to avoid division by
94 // zero error, and scale is > 1, so the integer value of epsilon for quant
95 // is just dominated by the zero point.
96 // Also, GetInvSqrtQuantizedMultiplierExp handles the scenario where the sum
97 // of input value squared is zero case well.
98 // So we don't even need to do handle the epsilon for quantized kernel case.
99 const float epsilon = 1e-6f;
100 if (output->type == kTfLiteFloat32) {
101#define TF_LITE_L2NORM(type) \
102 tflite::L2NormalizationParams op_params; \
103 op_params.input_zero_point = 0; \
104 type::L2Normalization(op_params, GetTensorShape(input), \
105 GetTensorData<float>(input), GetTensorShape(output), \
106 GetTensorData<float>(output), epsilon)
107
108 if (kernel_type == kReference) {
109 TF_LITE_L2NORM(reference_ops);
110 }
111 if (kernel_type == kGenericOptimized) {
112 TF_LITE_L2NORM(optimized_ops);
113 }
114#undef TF_LITE_L2NORM
115 } else if (output->type == kTfLiteUInt8) {
116#define TF_LITE_L2NORM(type) \
117 tflite::L2NormalizationParams op_params; \
118 op_params.input_zero_point = input->params.zero_point; \
119 type::L2Normalization(op_params, GetTensorShape(input), \
120 GetTensorData<uint8>(input), GetTensorShape(output), \
121 GetTensorData<uint8>(output))
122
123 if (kernel_type == kReference) {
124 TF_LITE_L2NORM(reference_ops);
125 }
126 if (kernel_type == kGenericOptimized) {
127 TF_LITE_L2NORM(optimized_ops);
128 }
129#undef TF_LITE_L2NORM
130 } else if (output->type == kTfLiteInt8) {
131 const auto input_shape = GetTensorShape(input);
132 const auto output_shape = GetTensorShape(output);
133 const int trailing_dim = input_shape.DimensionsCount() - 1;
134 const int depth =
135 MatchingDim(input_shape, trailing_dim, output_shape, trailing_dim);
136 const int outer_size =
137 MatchingFlatSizeSkipDim(input_shape, trailing_dim, output_shape);
138 reference_integer_ops::L2Normalization(input->params.zero_point, outer_size,
139 depth, GetTensorData<int8>(input),
140 GetTensorData<int8>(output));
141 } else {
142 TF_LITE_KERNEL_LOG(context, "Output type is %s, requires float.",
143 TfLiteTypeGetName(output->type));
144 return kTfLiteError;
145 }
146
147 return kTfLiteOk;
148}
149
150} // namespace l2norm
151
152TfLiteRegistration* Register_L2NORM_REF() {
153 static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
154 l2norm::Eval<l2norm::kReference>};
155 return &r;
156}
157
158TfLiteRegistration* Register_L2NORM_GENERIC_OPT() {
159 static TfLiteRegistration r = {nullptr, nullptr, l2norm::Prepare,
160 l2norm::Eval<l2norm::kGenericOptimized>};
161 return &r;
162}
163
164TfLiteRegistration* Register_L2_NORMALIZATION() {
165 return Register_L2NORM_GENERIC_OPT();
166}
167
168} // namespace builtin
169} // namespace ops
170} // namespace tflite
171