1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <stdint.h>

#include <cstring>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/kernel_util.h"
23
24namespace tflite {
25namespace ops {
26namespace builtin {
27namespace reverse {
28namespace {
29
// Positions of this op's tensors within the node's input/output lists.
constexpr int kInputTensor{0};   // input slot: data to reverse
constexpr int kAxisTensor{1};    // input slot: int32 axis tensor
constexpr int kOutputTensor{0};  // output slot: reversed result
33
34TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
35 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
36 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
37
38 const TfLiteTensor* input;
39 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
40 const TfLiteTensor* axis;
41 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kAxisTensor, &axis));
42 TF_LITE_ENSURE_EQ(context, NumDimensions(axis), 1);
43 TF_LITE_ENSURE(context, NumDimensions(input) >= NumElements(axis));
44
45 if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
46 input->type != kTfLiteUInt8 && input->type != kTfLiteInt8 &&
47 input->type != kTfLiteInt16 && input->type != kTfLiteInt64 &&
48 input->type != kTfLiteBool) {
49 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by reverse.",
50 TfLiteTypeGetName(input->type));
51 return kTfLiteError;
52 }
53
54 if (axis->type != kTfLiteInt32) {
55 TF_LITE_KERNEL_LOG(context, "Axis Type '%s' is not supported by reverse.",
56 TfLiteTypeGetName(axis->type));
57 return kTfLiteError;
58 }
59
60 // TODO(b/186320180): support multi-axis case.
61 if (NumElements(axis) > 1) {
62 TF_LITE_KERNEL_LOG(context, "Current does not support more than 1 axis.");
63 }
64
65 TfLiteTensor* output;
66 TF_LITE_ENSURE_OK(context,
67 GetOutputSafe(context, node, kOutputTensor, &output));
68 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
69 TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
70
71 return context->ResizeTensor(context, output, output_shape);
72}
73
74TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
75 const TfLiteTensor* input;
76 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
77 const TfLiteTensor* axis_tensor;
78 TF_LITE_ENSURE_OK(context,
79 GetInputSafe(context, node, kAxisTensor, &axis_tensor));
80 int axis = GetTensorData<int32_t>(axis_tensor)[0];
81 const int rank = NumDimensions(input);
82 if (axis < 0) {
83 axis += rank;
84 }
85
86 TF_LITE_ENSURE(context, axis >= 0 && axis < rank);
87 TfLiteTensor* output;
88 TF_LITE_ENSURE_OK(context,
89 GetOutputSafe(context, node, kOutputTensor, &output));
90
91 switch (output->type) {
92 case kTfLiteFloat32: {
93 reference_ops::Reverse<float>(
94 axis, GetTensorShape(input), GetTensorData<float>(input),
95 GetTensorShape(output), GetTensorData<float>(output));
96 break;
97 }
98 case kTfLiteUInt8:
99 case kTfLiteInt8: {
100 reference_ops::Reverse<uint8_t>(
101 axis, GetTensorShape(input), GetTensorData<uint8_t>(input),
102 GetTensorShape(output), GetTensorData<uint8_t>(output));
103 break;
104 }
105 case kTfLiteInt16: {
106 reference_ops::Reverse<int16_t>(
107 axis, GetTensorShape(input), GetTensorData<int16_t>(input),
108 GetTensorShape(output), GetTensorData<int16_t>(output));
109 break;
110 }
111 case kTfLiteInt32: {
112 reference_ops::Reverse<int32_t>(
113 axis, GetTensorShape(input), GetTensorData<int32_t>(input),
114 GetTensorShape(output), GetTensorData<int32_t>(output));
115 break;
116 }
117 case kTfLiteInt64: {
118 reference_ops::Reverse<int64_t>(
119 axis, GetTensorShape(input), GetTensorData<int64_t>(input),
120 GetTensorShape(output), GetTensorData<int64_t>(output));
121 break;
122 }
123 case kTfLiteBool: {
124 reference_ops::Reverse<bool>(
125 axis, GetTensorShape(input), GetTensorData<bool>(input),
126 GetTensorShape(output), GetTensorData<bool>(output));
127 break;
128 }
129 default: {
130 TF_LITE_KERNEL_LOG(context, "Type '%s' is not supported by reverse.",
131 TfLiteTypeGetName(output->type));
132 return kTfLiteError;
133 }
134 }
135
136 return kTfLiteOk;
137}
138
139} // namespace
140} // namespace reverse
141
142TfLiteRegistration* Register_REVERSE_V2() {
143 static TfLiteRegistration r = {nullptr, nullptr, reverse::Prepare,
144 reverse::Eval};
145 return &r;
146}
147
148} // namespace builtin
149} // namespace ops
150} // namespace tflite
151