1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include <stdint.h> |
16 | |
17 | #include "tensorflow/lite/c/builtin_op_data.h" |
18 | #include "tensorflow/lite/c/common.h" |
19 | #include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h" |
20 | #include "tensorflow/lite/kernels/internal/reference/reference_ops.h" |
21 | #include "tensorflow/lite/kernels/internal/tensor.h" |
22 | #include "tensorflow/lite/kernels/internal/tensor_ctypes.h" |
23 | #include "tensorflow/lite/kernels/internal/types.h" |
24 | #include "tensorflow/lite/kernels/kernel_util.h" |
25 | |
26 | namespace tflite { |
27 | namespace ops { |
28 | namespace builtin { |
29 | namespace split { |
30 | |
31 | struct OpContext { |
32 | OpContext(TfLiteContext* context, TfLiteNode* node) { |
33 | params = reinterpret_cast<TfLiteSplitParams*>(node->builtin_data); |
34 | axis = GetInput(context, node, 0); |
35 | input = GetInput(context, node, 1); |
36 | } |
37 | TfLiteSplitParams* params; |
38 | const TfLiteTensor* axis; |
39 | const TfLiteTensor* input; |
40 | }; |
41 | |
42 | TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) { |
43 | for (int i = 0; i < NumOutputs(node); ++i) { |
44 | TfLiteTensor* tensor; |
45 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor)); |
46 | SetTensorToDynamic(tensor); |
47 | } |
48 | return kTfLiteOk; |
49 | } |
50 | |
51 | TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node, |
52 | const TfLiteTensor* axis, |
53 | const TfLiteTensor* input, int num_splits) { |
54 | int axis_value = GetTensorData<int>(axis)[0]; |
55 | if (axis_value < 0) { |
56 | axis_value += NumDimensions(input); |
57 | } |
58 | |
59 | TF_LITE_ENSURE(context, axis_value >= 0); |
60 | TF_LITE_ENSURE(context, axis_value < NumDimensions(input)); |
61 | |
62 | const int input_size = SizeOfDimension(input, axis_value); |
63 | TF_LITE_ENSURE(context, num_splits != 0); |
64 | TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0, |
65 | "Not an even split" ); |
66 | const int slice_size = input_size / num_splits; |
67 | |
68 | for (int i = 0; i < NumOutputs(node); ++i) { |
69 | TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims); |
70 | output_dims->data[axis_value] = slice_size; |
71 | TfLiteTensor* output; |
72 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output)); |
73 | TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims)); |
74 | } |
75 | |
76 | return kTfLiteOk; |
77 | } |
78 | |
79 | TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { |
80 | TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); |
81 | |
82 | OpContext op_context(context, node); |
83 | |
84 | TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits); |
85 | |
86 | auto input_type = op_context.input->type; |
87 | TF_LITE_ENSURE(context, |
88 | input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 || |
89 | input_type == kTfLiteInt8 || input_type == kTfLiteInt16 || |
90 | input_type == kTfLiteInt32); |
91 | for (int i = 0; i < NumOutputs(node); ++i) { |
92 | TfLiteTensor* tensor; |
93 | TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor)); |
94 | tensor->type = input_type; |
95 | } |
96 | |
97 | // If we know the contents of the 'axis' tensor, resize all outputs. |
98 | // Otherwise, wait until Eval(). |
99 | if (IsConstantTensor(op_context.axis)) { |
100 | return ResizeOutputTensors(context, node, op_context.axis, op_context.input, |
101 | op_context.params->num_splits); |
102 | } else { |
103 | return UseDynamicOutputTensors(context, node); |
104 | } |
105 | } |
106 | |
107 | TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { |
108 | OpContext op_context(context, node); |
109 | |
110 | // When the 'axis' tensor is non-const we can't resize output tensors in |
111 | // Prepare(), and we have to do it now. |
112 | if (!IsConstantTensor(op_context.axis)) { |
113 | TF_LITE_ENSURE_OK( |
114 | context, |
115 | ResizeOutputTensors(context, node, op_context.axis, op_context.input, |
116 | op_context.params->num_splits)); |
117 | } |
118 | |
119 | int axis_value = GetTensorData<int>(op_context.axis)[0]; |
120 | if (axis_value < 0) { |
121 | axis_value += NumDimensions(op_context.input); |
122 | } |
123 | |
124 | TF_LITE_ENSURE(context, axis_value >= 0); |
125 | TF_LITE_ENSURE(context, axis_value < NumDimensions(op_context.input)); |
126 | |
127 | // TODO(b/173221795): Our usage of VectorOfTensors could be optimized by |
128 | // calculating it in Prepare, unless we defer shape calculation. |
129 | // We can improve the optimized_ops version to handle other |
130 | // cases too. |
131 | #define TF_LITE_SPLIT(scalar) \ |
132 | VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \ |
133 | tflite::SplitParams op_params; \ |
134 | op_params.num_split = NumOutputs(node); \ |
135 | op_params.axis = axis_value; \ |
136 | reference_ops::Split(op_params, GetTensorShape(op_context.input), \ |
137 | GetTensorData<scalar>(op_context.input), \ |
138 | all_outputs.shapes(), all_outputs.data()); |
139 | |
140 | switch (op_context.input->type) { |
141 | case kTfLiteFloat32: { |
142 | TF_LITE_SPLIT(float); |
143 | break; |
144 | } |
145 | case kTfLiteUInt8: { |
146 | TF_LITE_SPLIT(uint8_t); |
147 | break; |
148 | } |
149 | case kTfLiteInt8: { |
150 | TF_LITE_SPLIT(int8_t); |
151 | break; |
152 | } |
153 | case kTfLiteInt16: { |
154 | TF_LITE_SPLIT(int16_t); |
155 | break; |
156 | } |
157 | case kTfLiteInt32: { |
158 | TF_LITE_SPLIT(int32_t); |
159 | break; |
160 | } |
161 | default: |
162 | TF_LITE_KERNEL_LOG(context, "Type %s currently not supported." , |
163 | TfLiteTypeGetName(op_context.input->type)); |
164 | return kTfLiteError; |
165 | } |
166 | #undef TF_LITE_SPLIT |
167 | |
168 | return kTfLiteOk; |
169 | } |
170 | |
171 | } // namespace split |
172 | |
173 | TfLiteRegistration* Register_SPLIT() { |
174 | static TfLiteRegistration r = {nullptr, nullptr, split::Prepare, split::Eval}; |
175 | return &r; |
176 | } |
177 | |
178 | } // namespace builtin |
179 | } // namespace ops |
180 | } // namespace tflite |
181 | |