1// Copyright 2021 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include <cmath>
16
17#include "tensorflow/lite/c/common.h"
18#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
19#include "tensorflow/lite/kernels/kernel_util.h"
20
21namespace tflite {
22namespace ops {
23namespace builtin {
24namespace atan2 {
25
26TfLiteStatus EnsureSameShape(
27 TfLiteContext* context,
28 const TfLiteTensor* a, const TfLiteTensor* b) {
29 TF_LITE_ENSURE_EQ(context,
30 tflite::NumDimensions(a),
31 tflite::NumDimensions(b));
32
33 return TfLiteStatus::kTfLiteOk;
34}
35
36TfLiteStatus Atan2Prepare(TfLiteContext* context, TfLiteNode* node) {
37 TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2);
38 TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1);
39
40 const TfLiteTensor* input_y = tflite::GetInput(context, node, 0);
41 const TfLiteTensor* input_x = tflite::GetInput(context, node, 1);
42 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
43
44 // Validate size and type constraints
45 TF_LITE_ENSURE_OK(context, EnsureSameShape(context, input_y, input_x));
46 TF_LITE_ENSURE_TYPES_EQ(context, input_y->type, input_x->type);
47 TF_LITE_ENSURE_TYPES_EQ(context, input_y->type, output->type);
48 TF_LITE_ENSURE(context,
49 input_y->type == kTfLiteFloat32 ||
50 input_y->type == kTfLiteFloat64);
51
52 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input_y->dims);
53
54 return context->ResizeTensor(context, output, output_shape);
55}
56
57template<typename Float>
58TfLiteStatus Atan2(const TfLiteTensor* input_y,
59 const TfLiteTensor* input_x,
60 TfLiteTensor* output) {
61 const Float* data_y = tflite::GetTensorData<Float>(input_y);
62 const Float* data_x = tflite::GetTensorData<Float>(input_x);
63 Float* data_output = tflite::GetTensorData<Float>(output);
64
65 const int64_t num_elements = NumElements(input_y);
66 for (int64_t i = 0; i < num_elements; ++i) {
67 data_output[i] = std::atan2(data_y[i], data_x[i]);
68 }
69
70 return TfLiteStatus::kTfLiteOk;
71}
72
73TfLiteStatus Atan2Eval(TfLiteContext* context, TfLiteNode* node) {
74 const TfLiteTensor* input_y = tflite::GetInput(context, node, 0);
75 const TfLiteTensor* input_x = tflite::GetInput(context, node, 1);
76 TfLiteTensor* output = tflite::GetOutput(context, node, 0);
77
78 switch (output->type) {
79 case kTfLiteFloat32:
80 TF_LITE_ENSURE_OK(context, Atan2<float>(input_y, input_x, output));
81 break;
82 case kTfLiteFloat64:
83 TF_LITE_ENSURE_OK(context, Atan2<double>(input_y, input_x, output));
84 break;
85 default:
86 TF_LITE_KERNEL_LOG(
87 context,
88 "Unsupported datatype for atan2 output: %s",
89 TfLiteTypeGetName(output->type));
90 }
91
92 return TfLiteStatus::kTfLiteOk;
93}
94
95} // namespace atan2
96
97TfLiteRegistration* Register_ATAN2() {
98 static TfLiteRegistration r = {
99 nullptr, nullptr, atan2::Atan2Prepare, atan2::Atan2Eval};
100 return &r;
101}
102
103} // namespace builtin
104} // namespace ops
105} // namespace tflite
106