1 | // Copyright 2021 Google LLC |
2 | // |
3 | // Licensed under the Apache License, Version 2.0 (the "License"); |
4 | // you may not use this file except in compliance with the License. |
5 | // You may obtain a copy of the License at |
6 | // |
7 | // http://www.apache.org/licenses/LICENSE-2.0 |
8 | // |
9 | // Unless required by applicable law or agreed to in writing, software |
10 | // distributed under the License is distributed on an "AS IS" BASIS, |
11 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | // See the License for the specific language governing permissions and |
13 | // limitations under the License. |
14 | |
15 | #include <cmath> |
16 | |
17 | #include "tensorflow/lite/c/common.h" |
18 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
19 | #include "tensorflow/lite/kernels/kernel_util.h" |
20 | |
21 | namespace tflite { |
22 | namespace ops { |
23 | namespace builtin { |
24 | namespace atan2 { |
25 | |
26 | TfLiteStatus EnsureSameShape( |
27 | TfLiteContext* context, |
28 | const TfLiteTensor* a, const TfLiteTensor* b) { |
29 | TF_LITE_ENSURE_EQ(context, |
30 | tflite::NumDimensions(a), |
31 | tflite::NumDimensions(b)); |
32 | |
33 | return TfLiteStatus::kTfLiteOk; |
34 | } |
35 | |
36 | TfLiteStatus Atan2Prepare(TfLiteContext* context, TfLiteNode* node) { |
37 | TF_LITE_ENSURE_EQ(context, tflite::NumInputs(node), 2); |
38 | TF_LITE_ENSURE_EQ(context, tflite::NumOutputs(node), 1); |
39 | |
40 | const TfLiteTensor* input_y = tflite::GetInput(context, node, 0); |
41 | const TfLiteTensor* input_x = tflite::GetInput(context, node, 1); |
42 | TfLiteTensor* output = tflite::GetOutput(context, node, 0); |
43 | |
44 | // Validate size and type constraints |
45 | TF_LITE_ENSURE_OK(context, EnsureSameShape(context, input_y, input_x)); |
46 | TF_LITE_ENSURE_TYPES_EQ(context, input_y->type, input_x->type); |
47 | TF_LITE_ENSURE_TYPES_EQ(context, input_y->type, output->type); |
48 | TF_LITE_ENSURE(context, |
49 | input_y->type == kTfLiteFloat32 || |
50 | input_y->type == kTfLiteFloat64); |
51 | |
52 | TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input_y->dims); |
53 | |
54 | return context->ResizeTensor(context, output, output_shape); |
55 | } |
56 | |
57 | template<typename Float> |
58 | TfLiteStatus Atan2(const TfLiteTensor* input_y, |
59 | const TfLiteTensor* input_x, |
60 | TfLiteTensor* output) { |
61 | const Float* data_y = tflite::GetTensorData<Float>(input_y); |
62 | const Float* data_x = tflite::GetTensorData<Float>(input_x); |
63 | Float* data_output = tflite::GetTensorData<Float>(output); |
64 | |
65 | const int64_t num_elements = NumElements(input_y); |
66 | for (int64_t i = 0; i < num_elements; ++i) { |
67 | data_output[i] = std::atan2(data_y[i], data_x[i]); |
68 | } |
69 | |
70 | return TfLiteStatus::kTfLiteOk; |
71 | } |
72 | |
73 | TfLiteStatus Atan2Eval(TfLiteContext* context, TfLiteNode* node) { |
74 | const TfLiteTensor* input_y = tflite::GetInput(context, node, 0); |
75 | const TfLiteTensor* input_x = tflite::GetInput(context, node, 1); |
76 | TfLiteTensor* output = tflite::GetOutput(context, node, 0); |
77 | |
78 | switch (output->type) { |
79 | case kTfLiteFloat32: |
80 | TF_LITE_ENSURE_OK(context, Atan2<float>(input_y, input_x, output)); |
81 | break; |
82 | case kTfLiteFloat64: |
83 | TF_LITE_ENSURE_OK(context, Atan2<double>(input_y, input_x, output)); |
84 | break; |
85 | default: |
86 | TF_LITE_KERNEL_LOG( |
87 | context, |
88 | "Unsupported datatype for atan2 output: %s" , |
89 | TfLiteTypeGetName(output->type)); |
90 | } |
91 | |
92 | return TfLiteStatus::kTfLiteOk; |
93 | } |
94 | |
95 | } // namespace atan2 |
96 | |
97 | TfLiteRegistration* Register_ATAN2() { |
98 | static TfLiteRegistration r = { |
99 | nullptr, nullptr, atan2::Atan2Prepare, atan2::Atan2Eval}; |
100 | return &r; |
101 | } |
102 | |
103 | } // namespace builtin |
104 | } // namespace ops |
105 | } // namespace tflite |
106 | |