1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | |
17 | #include <vector> |
18 | |
19 | #include "tensorflow/lite/c/builtin_op_data.h" |
20 | #include "tensorflow/lite/c/common.h" |
21 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
22 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
23 | #include "tensorflow/lite/kernels/internal/tensor.h" |
24 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
25 | #include "tensorflow/lite/kernels/internal/types.h" |
26 | #include "tensorflow/lite/kernels/kernel_util.h" |
27 | |
28 | namespace tflite { |
29 | namespace ops { |
30 | namespace builtin { |
31 | namespace split_v { |
32 | |
33 | struct OpContext { |
34 | OpContext(TfLiteContext* context, TfLiteNode* node) { |
35 | params = reinterpret_cast<TfLiteSplitVParams*>(node->builtin_data); |
36 | input = GetInput(context, node, 0); |
37 | size_splits = GetInput(context, node, 1); |
38 | axis = GetInput(context, node, 2); |
39 | } |
40 | TfLiteSplitVParams* params; |
41 | const TfLiteTensor* input; |
42 | const TfLiteTensor* size_splits; |
43 | const TfLiteTensor* axis; |
44 | }; |
45 | |
46 | TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) { |
47 | for (int i = 0; i < NumOutputs(node); ++i) { |
48 | TfLiteTensor* tensor; |
49 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor)); |
50 | SetTensorToDynamic(tensor); |
51 | } |
52 | return kTfLiteOk; |
53 | } |
54 | |
55 | template <typename T> |
56 | void GetSizeSplitsVector(const TfLiteTensor* size_splits, |
57 | std::vector<int64_t>* size_splits_vector) { |
58 | const auto num_elements = NumElements(size_splits); |
59 | for (int i = 0; i < num_elements; ++i) { |
60 | size_splits_vector->push_back(GetTensorData<T>(size_splits)[i]); |
61 | } |
62 | } |
63 | |
64 | TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node, |
65 | const TfLiteTensor* input, |
66 | const TfLiteTensor* size_splits, |
67 | const TfLiteTensor* axis) { |
68 | int axis_value = GetTensorData<int>(axis)[0]; |
69 | if (axis_value < 0) { |
70 | axis_value += NumDimensions(input); |
71 | } |
72 | |
73 | std::vector<int64_t> size_splits_vector; |
74 | if (size_splits->type == kTfLiteInt32) { |
75 | GetSizeSplitsVector<int32_t>(size_splits, &size_splits_vector); |
76 | } else if (size_splits->type == kTfLiteInt64) { |
77 | GetSizeSplitsVector<int64_t>(size_splits, &size_splits_vector); |
78 | } else { |
79 | TF_LITE_KERNEL_LOG(context, "size_splits only support type int32|int64." ); |
80 | return kTfLiteError; |
81 | } |
82 | |
83 | int minus_one_index = -1; |
84 | int64_t size_splits_sum = 0; |
85 | |
86 | for (int i = 0; i < size_splits_vector.size(); ++i) { |
87 | if (size_splits_vector.at(i) == -1) { |
88 | if (minus_one_index == -1) { |
89 | minus_one_index = i; |
90 | } else { |
91 | TF_LITE_KERNEL_LOG(context, |
92 | "The size_splits contains more than one -1." ); |
93 | return kTfLiteError; |
94 | } |
95 | } else { |
96 | size_splits_sum += size_splits_vector.at(i); |
97 | } |
98 | } |
99 | |
100 | TF_LITE_ENSURE(context, axis_value >= 0); |
101 | TF_LITE_ENSURE(context, axis_value < NumDimensions(input)); |
102 | const int input_size = SizeOfDimension(input, axis_value); |
103 | |
104 | if (minus_one_index != -1) { |
105 | if (size_splits_sum > input_size) { |
106 | TF_LITE_KERNEL_LOG( |
107 | context, |
108 | "The sum of size_splits must be less than the dimension of value." ); |
109 | } else { |
110 | size_splits_vector[minus_one_index] = input_size - size_splits_sum; |
111 | } |
112 | } else if (size_splits_sum != input_size) { |
113 | TF_LITE_KERNEL_LOG( |
114 | context, |
115 | "The size_splits must sum to the dimension of value along axis." ); |
116 | } |
117 | |
118 | for (int i = 0; i < NumOutputs(node); ++i) { |
119 | TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims); |
120 | output_dims->data[axis_value] = size_splits_vector.at(i); |
121 | TfLiteTensor* output; |
122 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); |
123 | TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims)); |
124 | } |
125 | |
126 | return kTfLiteOk; |
127 | } |
128 | |
129 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
130 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 3); |
131 | |
132 | OpContext op_context(context, node); |
133 | |
134 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits); |
135 | |
136 | auto input_type = op_context.input->type; |
137 | TF_LITE_ENSURE(context, |
138 | input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || |
139 | input_type == kTfLiteInt16 || input_type == kTfLiteInt32 || |
140 | input_type == kTfLiteInt64 || input_type == kTfLiteInt8); |
141 | for (int i = 0; i < NumOutputs(node); ++i) { |
142 | TfLiteTensor* tensor; |
143 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor)); |
144 | tensor->type = input_type; |
145 | } |
146 | |
147 | auto size_splits = op_context.size_splits; |
148 | TF_LITE_ENSURE_EQ(context, NumDimensions(size_splits), 1); |
149 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), NumElements(size_splits)); |
150 | |
151 | // If we know the contents of the 'size_splits' tensor and the 'axis' tensor, |
152 | // resize all outputs. Otherwise, wait until Eval(). |
153 | if (IsConstantTensor(op_context.size_splits) && |
154 | IsConstantTensor(op_context.axis)) { |
155 | return ResizeOutputTensors(context, node, op_context.input, |
156 | op_context.size_splits, op_context.axis); |
157 | } else { |
158 | return UseDynamicOutputTensors(context, node); |
159 | } |
160 | } |
161 | |
162 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
163 | OpContext op_context(context, node); |
164 | |
165 | // When the 'size_splits' and the 'axis' tensor is non-const we can't resize |
166 | // output tensors in Prepare(), and we have to do it now. |
167 | if (!IsConstantTensor(op_context.axis) || |
168 | !IsConstantTensor(op_context.size_splits)) { |
169 | TF_LITE_ENSURE_OK( |
170 | context, ResizeOutputTensors(context, node, op_context.input, |
171 | op_context.size_splits, op_context.axis)); |
172 | } |
173 | |
174 | int axis_value = GetTensorData<int>(op_context.axis)[0]; |
175 | |
176 | // Use split function to build the outputs since they share the same logic. |
177 | #define TF_LITE_SPLIT_V(scalar) \ |
178 | VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \ |
179 | tflite::SplitParams op_params; \ |
180 | op_params.num_split = NumOutputs(node); \ |
181 | op_params.axis = axis_value; \ |
182 | reference_ops::Split(op_params, GetTensorShape(op_context.input), \ |
183 | GetTensorData<scalar>(op_context.input), \ |
184 | all_outputs.shapes(), all_outputs.data()); |
185 | switch (op_context.input->type) { |
186 | case kTfLiteFloat32: { |
187 | TF_LITE_SPLIT_V(float); |
188 | break; |
189 | } |
190 | case kTfLiteUInt8: { |
191 | TF_LITE_SPLIT_V(uint8_t); |
192 | break; |
193 | } |
194 | case kTfLiteInt16: { |
195 | TF_LITE_SPLIT_V(int16_t); |
196 | break; |
197 | } |
198 | case kTfLiteInt32: { |
199 | TF_LITE_SPLIT_V(int32_t); |
200 | break; |
201 | } |
202 | case kTfLiteInt64: { |
203 | TF_LITE_SPLIT_V(int64_t); |
204 | break; |
205 | } |
206 | case kTfLiteInt8: { |
207 | TF_LITE_SPLIT_V(int8_t); |
208 | break; |
209 | } |
210 | default: |
211 | TF_LITE_KERNEL_LOG(context, "Type %s currently not supported." , |
212 | TfLiteTypeGetName(op_context.input->type)); |
213 | return kTfLiteError; |
214 | } |
215 | #undef TF_LITE_SPLIT_V |
216 | |
217 | return kTfLiteOk; |
218 | } |
219 | |
220 | } // namespace split_v |
221 | |
222 | TfLiteRegistration* Register_SPLIT_V() { |
223 | static TfLiteRegistration r = {nullptr, nullptr, split_v::Prepare, |
224 | split_v::Eval}; |
225 | return &r; |
226 | } |
227 | |
228 | } // namespace builtin |
229 | } // namespace ops |
230 | } // namespace tflite |
231 | |