1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <stdint.h>
17
18#include "tensorflow/lite/c/builtin_op_data.h"
19#include "tensorflow/lite/c/common.h"
20#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21#include "tensorflow/lite/kernels/internal/tensor.h"
22#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23#include "tensorflow/lite/kernels/kernel_util.h"
24
25namespace tflite {
26namespace ops {
27namespace builtin {
28namespace reverse_sequence {
29namespace {
30
// Positions of this op's tensors within node->inputs / node->outputs.
constexpr int kInputTensor{0};
constexpr int kSeqLengthsTensor{1};
constexpr int kOutputTensor{0};
34
35TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
36 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
37 TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
38
39 const TfLiteTensor* input;
40 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
41 const TfLiteTensor* seq_lengths;
42 TF_LITE_ENSURE_OK(
43 context, GetInputSafe(context, node, kSeqLengthsTensor, &seq_lengths));
44 TF_LITE_ENSURE_EQ(context, NumDimensions(seq_lengths), 1);
45
46 if (input->type != kTfLiteInt32 && input->type != kTfLiteFloat32 &&
47 input->type != kTfLiteUInt8 && input->type != kTfLiteInt16 &&
48 input->type != kTfLiteInt64) {
49 TF_LITE_KERNEL_LOG(context,
50 "Type '%s' is not supported by reverse_sequence.",
51 TfLiteTypeGetName(input->type));
52 return kTfLiteError;
53 }
54
55 if (seq_lengths->type != kTfLiteInt32 && seq_lengths->type != kTfLiteInt64) {
56 TF_LITE_KERNEL_LOG(
57 context, "Seq_lengths type '%s' is not supported by reverse_sequence.",
58 TfLiteTypeGetName(seq_lengths->type));
59 return kTfLiteError;
60 }
61
62 TfLiteTensor* output;
63 TF_LITE_ENSURE_OK(context,
64 GetOutputSafe(context, node, kOutputTensor, &output));
65 TfLiteIntArray* output_shape = TfLiteIntArrayCopy(input->dims);
66 TF_LITE_ENSURE_TYPES_EQ(context, output->type, input->type);
67
68 return context->ResizeTensor(context, output, output_shape);
69}
70
71template <typename T, typename TS>
72TfLiteStatus ReverseSequenceImpl(TfLiteContext* context, TfLiteNode* node) {
73 const TfLiteTensor* input;
74 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInputTensor, &input));
75 const TfLiteTensor* seq_lengths_tensor;
76 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
77 &seq_lengths_tensor));
78 const TS* seq_lengths = GetTensorData<TS>(seq_lengths_tensor);
79
80 auto* params =
81 reinterpret_cast<TfLiteReverseSequenceParams*>(node->builtin_data);
82 int seq_dim = params->seq_dim;
83 int batch_dim = params->batch_dim;
84
85 TF_LITE_ENSURE(context, seq_dim >= 0);
86 TF_LITE_ENSURE(context, batch_dim >= 0);
87 TF_LITE_ENSURE(context, seq_dim != batch_dim);
88 TF_LITE_ENSURE(context, seq_dim < NumDimensions(input));
89 TF_LITE_ENSURE(context, batch_dim < NumDimensions(input));
90 TF_LITE_ENSURE_EQ(context, SizeOfDimension(seq_lengths_tensor, 0),
91 SizeOfDimension(input, batch_dim));
92 for (int i = 0; i < NumDimensions(seq_lengths_tensor); ++i) {
93 TF_LITE_ENSURE(context, seq_lengths[i] <= SizeOfDimension(input, seq_dim));
94 }
95
96 TfLiteTensor* output;
97 TF_LITE_ENSURE_OK(context,
98 GetOutputSafe(context, node, kOutputTensor, &output));
99
100 reference_ops::ReverseSequence<T, TS>(
101 seq_lengths, seq_dim, batch_dim, GetTensorShape(input),
102 GetTensorData<T>(input), GetTensorShape(output),
103 GetTensorData<T>(output));
104
105 return kTfLiteOk;
106}
107
108template <typename T>
109TfLiteStatus ReverseSequenceHelper(TfLiteContext* context, TfLiteNode* node) {
110 const TfLiteTensor* seq_lengths_tensor;
111 TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kSeqLengthsTensor,
112 &seq_lengths_tensor));
113 switch (seq_lengths_tensor->type) {
114 case kTfLiteInt32: {
115 return ReverseSequenceImpl<T, int32_t>(context, node);
116 }
117 case kTfLiteInt64: {
118 return ReverseSequenceImpl<T, int64_t>(context, node);
119 }
120 default: {
121 TF_LITE_KERNEL_LOG(
122 context,
123 "Seq_lengths type '%s' is not supported by reverse_sequence.",
124 TfLiteTypeGetName(seq_lengths_tensor->type));
125 return kTfLiteError;
126 }
127 }
128 return kTfLiteOk;
129}
130
131TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
132 TfLiteTensor* output;
133 TF_LITE_ENSURE_OK(context,
134 GetOutputSafe(context, node, kOutputTensor, &output));
135
136 switch (output->type) {
137 case kTfLiteFloat32: {
138 return ReverseSequenceHelper<float>(context, node);
139 }
140 case kTfLiteUInt8: {
141 return ReverseSequenceHelper<uint8_t>(context, node);
142 }
143 case kTfLiteInt16: {
144 return ReverseSequenceHelper<int16_t>(context, node);
145 }
146 case kTfLiteInt32: {
147 return ReverseSequenceHelper<int32_t>(context, node);
148 }
149 case kTfLiteInt64: {
150 return ReverseSequenceHelper<int64_t>(context, node);
151 }
152 default: {
153 TF_LITE_KERNEL_LOG(context,
154 "Type '%s' is not supported by reverse_sequence.",
155 TfLiteTypeGetName(output->type));
156 return kTfLiteError;
157 }
158 }
159} // namespace
160
161} // namespace
162} // namespace reverse_sequence
163
164TfLiteRegistration* Register_REVERSE_SEQUENCE() {
165 static TfLiteRegistration r = {nullptr, nullptr, reverse_sequence::Prepare,
166 reverse_sequence::Eval};
167 return &r;
168}
169
170} // namespace builtin
171} // namespace ops
172} // namespace tflite
173