1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <stddef.h>
16#include <stdint.h>
17
18#include "tensorflow/lite/c/common.h"
19#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21#include "tensorflow/lite/kernels/internal/tensor.h"
22#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23#include "tensorflow/lite/kernels/kernel_util.h"
24
25namespace tflite {
26namespace ops {
27namespace builtin {
28namespace pow {
29namespace {
30
// Positions of the POW op's tensors within the node's input/output lists.
// First input tensor.
constexpr int kInputTensor1{0};
// Second input tensor.
constexpr int kInputTensor2{1};
// The single output tensor.
constexpr int kOutputTensor{0};
35
// Per-node state for the POW op. Instances are created in Init(), deleted in
// Free(), written in Prepare() and read in Eval().
struct OpData {
  // True when the two inputs do not have the same shape, i.e. when Eval() must
  // take the broadcasting code path. Defaults to false so that an OpData is
  // never observed with an indeterminate flag, however it is constructed.
  bool requires_broadcast = false;
};
40
41void* Init(TfLiteContext* context, const char* buffer, size_t length) {
42 auto* data = new OpData;
43 data->requires_broadcast = false;
44 return data;
45}
46
47void Free(TfLiteContext* context, void* buffer) {
48 delete reinterpret_cast<OpData*>(buffer);
49}
50
51TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
52 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
53 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
54
55 OpData* data = reinterpret_cast<OpData*>(node->user_data);
56
57 const TfLiteTensor* input1;
58 TF_LITE_ENSURE_OK(context,
59 GetInputSafe(context, node, kInputTensor1, &input1));
60 const TfLiteTensor* input2;
61 TF_LITE_ENSURE_OK(context,
62 GetInputSafe(context, node, kInputTensor2, &input2));
63 TfLiteTensor* output;
64 TF_LITE_ENSURE_OK(context,
65 GetOutputSafe(context, node, kOutputTensor, &output));
66
67 TF_LITE_ENSURE_TYPES_EQ(context, input1->type, input2->type);
68
69 const TfLiteType type = input1->type;
70 if (type != kTfLiteInt32 && type != kTfLiteFloat32) {
71 TF_LITE_KERNEL_LOG(context, "Unsupported data type %s.",
72 TfLiteTypeGetName(type));
73 return kTfLiteError;
74 }
75 output->type = type;
76
77 data->requires_broadcast = !HaveSameShapes(input1, input2);
78
79 TfLiteIntArray* output_size = nullptr;
80 if (data->requires_broadcast) {
81 TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
82 context, input1, input2, &output_size));
83 } else {
84 output_size = TfLiteIntArrayCopy(input1->dims);
85 }
86
87 return context->ResizeTensor(context, output, output_size);
88}
89
90template <typename T>
91void PowImpl(const TfLiteTensor* input1, const TfLiteTensor* input2,
92 TfLiteTensor* output, bool requires_broadcast) {
93 if (requires_broadcast) {
94 optimized_ops::BroadcastPow4D(
95 GetTensorShape(input1), GetTensorData<T>(input1),
96 GetTensorShape(input2), GetTensorData<T>(input2),
97 GetTensorShape(output), GetTensorData<T>(output));
98 } else {
99 reference_ops::Pow(GetTensorShape(input1), GetTensorData<T>(input1),
100 GetTensorShape(input2), GetTensorData<T>(input2),
101 GetTensorShape(output), GetTensorData<T>(output));
102 }
103}
104
105TfLiteStatus CheckValue(TfLiteContext* context, const TfLiteTensor* input) {
106 const int64_t num_elements = NumElements(input);
107 const int32_t* data = GetTensorData<int32_t>(input);
108 for (int i = 0; i < num_elements; ++i) {
109 if (data[i] < 0) {
110 TF_LITE_KERNEL_LOG(context,
111 "POW does not support negative value for int32.");
112 return kTfLiteError;
113 }
114 }
115 return kTfLiteOk;
116}
117
118TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
119 OpData* data = reinterpret_cast<OpData*>(node->user_data);
120
121 const TfLiteTensor* input1;
122 TF_LITE_ENSURE_OK(context,
123 GetInputSafe(context, node, kInputTensor1, &input1));
124 const TfLiteTensor* input2;
125 TF_LITE_ENSURE_OK(context,
126 GetInputSafe(context, node, kInputTensor2, &input2));
127 TfLiteTensor* output;
128 TF_LITE_ENSURE_OK(context,
129 GetOutputSafe(context, node, kOutputTensor, &output));
130
131 switch (output->type) {
132 case kTfLiteInt32: {
133 // TensorFlow does not support negative for int32.
134 TF_LITE_ENSURE_OK(context, CheckValue(context, input2));
135 PowImpl<int32_t>(input1, input2, output, data->requires_broadcast);
136 break;
137 }
138 case kTfLiteFloat32: {
139 PowImpl<float>(input1, input2, output, data->requires_broadcast);
140 break;
141 }
142 default: {
143 TF_LITE_KERNEL_LOG(context, "Unsupported data type: %d", output->type);
144 return kTfLiteError;
145 }
146 }
147 return kTfLiteOk;
148}
149
150} // namespace
151} // namespace pow
152
153TfLiteRegistration* Register_POW() {
154 static TfLiteRegistration r = {pow::Init, pow::Free, pow::Prepare, pow::Eval};
155 return &r;
156}
157
158} // namespace builtin
159} // namespace ops
160} // namespace tflite
161