/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stdint.h>

#include <vector>

#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
#include "tensorflow/lite/kernels/internal/types.h"
#include "tensorflow/lite/kernels/kernel_util.h"
namespace tflite {
namespace ops {
namespace builtin {
namespace split_v {
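
// Gathers this op's data: the tensor to split (input 0), the 1-D
// 'size_splits' tensor holding the size of each output along the split axis
// (input 1), and the scalar 'axis' tensor (input 2).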
struct OpContext {
  OpContext(TfLiteContext* context, TfLiteNode* node) {
    params = reinterpret_cast<TfLiteSplitVParams*>(node->builtin_data);
    input = GetInput(context, node, 0);
    size_splits = GetInput(context, node, 1);
    axis = GetInput(context, node, 2);
  }
  TfLiteSplitVParams* params;
  const TfLiteTensor* input;
  const TfLiteTensor* size_splits;
  const TfLiteTensor* axis;
};
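
// Marks every output as a dynamic tensor so that it is allocated during
// Eval(), once the actual 'size_splits' and 'axis' values are known.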
TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
  for (int i = 0; i < NumOutputs(node); ++i) {
    TfLiteTensor* tensor;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
    SetTensorToDynamic(tensor);
  }
  return kTfLiteOk;
}
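
// Copies the contents of 'size_splits' into an int64_t vector so that the
// rest of the code can handle int32 and int64 splits uniformly.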
template <typename T>
void GetSizeSplitsVector(const TfLiteTensor* size_splits,
                         std::vector<int64_t>* size_splits_vector) {
  const auto num_elements = NumElements(size_splits);
  for (int i = 0; i < num_elements; ++i) {
    size_splits_vector->push_back(GetTensorData<T>(size_splits)[i]);
  }
}
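
// Computes each output's shape from 'size_splits' and 'axis' and resizes the
// output tensors accordingly. At most one entry of size_splits may be -1; it
// absorbs whatever remains of the input dimension. For example, splitting an
// input of shape {4, 10} along axis 1 with size_splits {2, -1, 3} resolves
// the -1 to 10 - (2 + 3) = 5, yielding outputs of shapes {4, 2}, {4, 5} and
// {4, 3}.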
TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
                                 const TfLiteTensor* input,
                                 const TfLiteTensor* size_splits,
                                 const TfLiteTensor* axis) {
  int axis_value = GetTensorData<int>(axis)[0];
  if (axis_value < 0) {
    axis_value += NumDimensions(input);
  }

  std::vector<int64_t> size_splits_vector;
  if (size_splits->type == kTfLiteInt32) {
    GetSizeSplitsVector<int32_t>(size_splits, &size_splits_vector);
  } else if (size_splits->type == kTfLiteInt64) {
    GetSizeSplitsVector<int64_t>(size_splits, &size_splits_vector);
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "size_splits only supports int32 or int64 types.");
    return kTfLiteError;
  }

  int minus_one_index = -1;
  int64_t size_splits_sum = 0;

  for (int i = 0; i < size_splits_vector.size(); ++i) {
    if (size_splits_vector.at(i) == -1) {
      if (minus_one_index == -1) {
        minus_one_index = i;
      } else {
        TF_LITE_KERNEL_LOG(context,
                           "The size_splits contains more than one -1.");
        return kTfLiteError;
      }
    } else {
      size_splits_sum += size_splits_vector.at(i);
    }
  }

  TF_LITE_ENSURE(context, axis_value >= 0);
  TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
  const int input_size = SizeOfDimension(input, axis_value);

  if (minus_one_index != -1) {
    if (size_splits_sum > input_size) {
      TF_LITE_KERNEL_LOG(
          context,
          "The sum of size_splits must not exceed the input size along axis.");
      return kTfLiteError;
    }
    size_splits_vector[minus_one_index] = input_size - size_splits_sum;
  } else if (size_splits_sum != input_size) {
    TF_LITE_KERNEL_LOG(
        context,
        "The size_splits must sum to the input size along the given axis.");
    return kTfLiteError;
  }

  for (int i = 0; i < NumOutputs(node); ++i) {
    TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
    output_dims->data[axis_value] = size_splits_vector.at(i);
    TfLiteTensor* output;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
    TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
  }

  return kTfLiteOk;
}
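
// Checks input/output counts and types. The outputs can only be resized here
// when both 'size_splits' and 'axis' are constant tensors; otherwise they are
// left dynamic and resized in Eval().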
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 3);

  OpContext op_context(context, node);

  TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);

  auto input_type = op_context.input->type;
  TF_LITE_ENSURE(context,
                 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
                     input_type == kTfLiteInt16 || input_type == kTfLiteInt32 ||
                     input_type == kTfLiteInt64 || input_type == kTfLiteInt8);
  for (int i = 0; i < NumOutputs(node); ++i) {
    TfLiteTensor* tensor;
    TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
    tensor->type = input_type;
  }

  auto size_splits = op_context.size_splits;
  TF_LITE_ENSURE_EQ(context, NumDimensions(size_splits), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), NumElements(size_splits));

  // If we know the contents of the 'size_splits' tensor and the 'axis' tensor,
  // resize all outputs. Otherwise, wait until Eval().
  if (IsConstantTensor(op_context.size_splits) &&
      IsConstantTensor(op_context.axis)) {
    return ResizeOutputTensors(context, node, op_context.input,
                               op_context.size_splits, op_context.axis);
  } else {
    return UseDynamicOutputTensors(context, node);
  }
}
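
// Resizes the outputs if necessary, then copies each slice of the input into
// its output tensor.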
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  OpContext op_context(context, node);

  // When the 'size_splits' or the 'axis' tensor is non-const we can't resize
  // the output tensors in Prepare(), and we have to do it now.
  if (!IsConstantTensor(op_context.axis) ||
      !IsConstantTensor(op_context.size_splits)) {
    TF_LITE_ENSURE_OK(
        context, ResizeOutputTensors(context, node, op_context.input,
                                     op_context.size_splits, op_context.axis));
  }

  int axis_value = GetTensorData<int>(op_context.axis)[0];

  // Use the Split implementation to build the outputs, since the copy logic
  // is identical to Split's once the output shapes are known.
#define TF_LITE_SPLIT_V(scalar)                                     \
  VectorOfTensors<scalar> all_outputs(*context, *node->outputs);    \
  tflite::SplitParams op_params;                                    \
  op_params.num_split = NumOutputs(node);                           \
  op_params.axis = axis_value;                                      \
  reference_ops::Split(op_params, GetTensorShape(op_context.input), \
                       GetTensorData<scalar>(op_context.input),     \
                       all_outputs.shapes(), all_outputs.data());
  switch (op_context.input->type) {
    case kTfLiteFloat32: {
      TF_LITE_SPLIT_V(float);
      break;
    }
    case kTfLiteUInt8: {
      TF_LITE_SPLIT_V(uint8_t);
      break;
    }
    case kTfLiteInt16: {
      TF_LITE_SPLIT_V(int16_t);
      break;
    }
    case kTfLiteInt32: {
      TF_LITE_SPLIT_V(int32_t);
      break;
    }
    case kTfLiteInt64: {
      TF_LITE_SPLIT_V(int64_t);
      break;
    }
    case kTfLiteInt8: {
      TF_LITE_SPLIT_V(int8_t);
      break;
    }
    default:
      TF_LITE_KERNEL_LOG(context, "Type %s is currently not supported.",
                         TfLiteTypeGetName(op_context.input->type));
      return kTfLiteError;
  }
#undef TF_LITE_SPLIT_V

  return kTfLiteOk;
}

}  // namespace split_v
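
// SplitV keeps no per-node state, so the init and free entries of the
// registration are null.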
TfLiteRegistration* Register_SPLIT_V() {
  static TfLiteRegistration r = {nullptr, nullptr, split_v::Prepare,
                                 split_v::Eval};
  return &r;
}

}  // namespace builtin
}  // namespace ops
}  // namespace tflite