1// Copyright 2021 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <cmath>
16
17#include "tensorflow/lite/c/common.h"
18#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
19#include "tensorflow/lite/kernels/kernel_util.h"
20
21namespace tflite {
22namespace ops {
23namespace builtin {
24namespace sign {
25
26// Performs common preparation for pointwise, unary ops, i.e., type checks and
27// output tensor resizing.
28TfLiteStatus PointwiseUnaryOpPrepare(TfLiteContext* context, TfLiteNode* node) {
29 TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 1);
30
31 const TfLiteTensor* input = tflite::GetInput(context, node, 0);
32 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
33
34 // Validate size and type constraints
35 TF_LITE_ENSURE_TYPES_EQ(context, input->type, output->type);
36 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
37 return context->ResizeTensor(context, output, output_shape);
38}
39
40// Applies the operator Op pointwise to data of type T.
41template <typename Op, typename T>
42TfLiteStatus PointwiseUnaryOpDoEval(
43 TfLiteContext* context,
44 const TfLiteTensor* input,
45 TfLiteTensor* output) {
46 const T* data = tflite::GetTensorData<T>(input);
47 T* data_output = tflite::GetTensorData<T>(output);
48
49 const int64_t num_elements = NumElements(input);
50 for (int64_t i = 0; i < num_elements; ++i) {
51 data_output[i] = Op::template Eval<T>(data[i]);
52 }
53
54 return TfLiteStatus::kTfLiteOk;
55}
56
57// A generic evaluation function where the actual data processing is handled
58// by the Op::Eval<T> function.
59template <typename Op>
60TfLiteStatus PointwiseUnaryOpEval(TfLiteContext* context, TfLiteNode* node) {
61 const TfLiteTensor* input = tflite::GetInput(context, node, 0);
62 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
63
64 switch (output->type) {
65 case kTfLiteFloat32:
66 TF_LITE_ENSURE_OK(
67 context,
68 (PointwiseUnaryOpDoEval<Op, float>(context, input, output)));
69 break;
70 case kTfLiteFloat64:
71 TF_LITE_ENSURE_OK(
72 context,
73 (PointwiseUnaryOpDoEval<Op, double>(context, input, output)));
74 break;
75 default:
76 TF_LITE_KERNEL_LOG(context, "Unsupported datatype for sign output: %s",
77 TfLiteTypeGetName(output->type));
78 }
79
80 return TfLiteStatus::kTfLiteOk;
81}
82
// Operator implementing the sign function element-wise.
//
// Eval returns T(1) when x > 0, T(-1) when x < 0, and T(0) otherwise. The
// "otherwise" case covers +0 and -0, which both map to +0. It also covers NaN,
// because both comparisons are false for NaN.
struct Sign {
  template <typename T>
  static T Eval(T x) {
    return (x > 0) ? T(1) : ((x < 0) ? T(-1) : T(0));
  }
};
96
97} // namespace sign
98
99TfLiteRegistration* Register_SIGN() {
100 static TfLiteRegistration r = {nullptr, nullptr,
101 sign::PointwiseUnaryOpPrepare,
102 sign::PointwiseUnaryOpEval<sign::Sign>};
103 return &r;
104}
105
106} // namespace builtin
107} // namespace ops
108} // namespace tflite
109