1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include <stdint.h>
16
17#include "tensorflow/lite/c/builtin_op_data.h"
18#include "tensorflow/lite/c/common.h"
19#include "tensorflow/lite/kernels/internal/optimized/optimized_ops.h"
20#include "tensorflow/lite/kernels/internal/reference/reference_ops.h"
21#include "tensorflow/lite/kernels/internal/tensor.h"
22#include "tensorflow/lite/kernels/internal/tensor_ctypes.h"
23#include "tensorflow/lite/kernels/internal/types.h"
24#include "tensorflow/lite/kernels/kernel_util.h"
25
26namespace tflite {
27namespace ops {
28namespace builtin {
29namespace split {
30
31struct OpContext {
32 OpContext(TfLiteContext* context, TfLiteNode* node) {
33 params = reinterpret_cast<TfLiteSplitParams*>(node->builtin_data);
34 axis = GetInput(context, node, 0);
35 input = GetInput(context, node, 1);
36 }
37 TfLiteSplitParams* params;
38 const TfLiteTensor* axis;
39 const TfLiteTensor* input;
40};
41
42TfLiteStatus UseDynamicOutputTensors(TfLiteContext* context, TfLiteNode* node) {
43 for (int i = 0; i < NumOutputs(node); ++i) {
44 TfLiteTensor* tensor;
45 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
46 SetTensorToDynamic(tensor);
47 }
48 return kTfLiteOk;
49}
50
51TfLiteStatus ResizeOutputTensors(TfLiteContext* context, TfLiteNode* node,
52 const TfLiteTensor* axis,
53 const TfLiteTensor* input, int num_splits) {
54 int axis_value = GetTensorData<int>(axis)[0];
55 if (axis_value < 0) {
56 axis_value += NumDimensions(input);
57 }
58
59 TF_LITE_ENSURE(context, axis_value >= 0);
60 TF_LITE_ENSURE(context, axis_value < NumDimensions(input));
61
62 const int input_size = SizeOfDimension(input, axis_value);
63 TF_LITE_ENSURE(context, num_splits != 0);
64 TF_LITE_ENSURE_MSG(context, input_size % num_splits == 0,
65 "Not an even split");
66 const int slice_size = input_size / num_splits;
67
68 for (int i = 0; i < NumOutputs(node); ++i) {
69 TfLiteIntArray* output_dims = TfLiteIntArrayCopy(input->dims);
70 output_dims->data[axis_value] = slice_size;
71 TfLiteTensor* output;
72 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &output));
73 TF_LITE_ENSURE_STATUS(context->ResizeTensor(context, output, output_dims));
74 }
75
76 return kTfLiteOk;
77}
78
79TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
80 TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
81
82 OpContext op_context(context, node);
83
84 TF_LITE_ENSURE_EQ(context, NumOutputs(node), op_context.params->num_splits);
85
86 auto input_type = op_context.input->type;
87 TF_LITE_ENSURE(context,
88 input_type == kTfLiteFloat32 || input_type == kTfLiteUInt8 ||
89 input_type == kTfLiteInt8 || input_type == kTfLiteInt16 ||
90 input_type == kTfLiteInt32);
91 for (int i = 0; i < NumOutputs(node); ++i) {
92 TfLiteTensor* tensor;
93 TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, i, &tensor));
94 tensor->type = input_type;
95 }
96
97 // If we know the contents of the 'axis' tensor, resize all outputs.
98 // Otherwise, wait until Eval().
99 if (IsConstantTensor(op_context.axis)) {
100 return ResizeOutputTensors(context, node, op_context.axis, op_context.input,
101 op_context.params->num_splits);
102 } else {
103 return UseDynamicOutputTensors(context, node);
104 }
105}
106
107TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
108 OpContext op_context(context, node);
109
110 // When the 'axis' tensor is non-const we can't resize output tensors in
111 // Prepare(), and we have to do it now.
112 if (!IsConstantTensor(op_context.axis)) {
113 TF_LITE_ENSURE_OK(
114 context,
115 ResizeOutputTensors(context, node, op_context.axis, op_context.input,
116 op_context.params->num_splits));
117 }
118
119 int axis_value = GetTensorData<int>(op_context.axis)[0];
120 if (axis_value < 0) {
121 axis_value += NumDimensions(op_context.input);
122 }
123
124 TF_LITE_ENSURE(context, axis_value >= 0);
125 TF_LITE_ENSURE(context, axis_value < NumDimensions(op_context.input));
126
127 // TODO(b/173221795): Our usage of VectorOfTensors could be optimized by
128 // calculating it in Prepare, unless we defer shape calculation.
129 // We can improve the optimized_ops version to handle other
130 // cases too.
131#define TF_LITE_SPLIT(scalar) \
132 VectorOfTensors<scalar> all_outputs(*context, *node->outputs); \
133 tflite::SplitParams op_params; \
134 op_params.num_split = NumOutputs(node); \
135 op_params.axis = axis_value; \
136 reference_ops::Split(op_params, GetTensorShape(op_context.input), \
137 GetTensorData<scalar>(op_context.input), \
138 all_outputs.shapes(), all_outputs.data());
139
140 switch (op_context.input->type) {
141 case kTfLiteFloat32: {
142 TF_LITE_SPLIT(float);
143 break;
144 }
145 case kTfLiteUInt8: {
146 TF_LITE_SPLIT(uint8_t);
147 break;
148 }
149 case kTfLiteInt8: {
150 TF_LITE_SPLIT(int8_t);
151 break;
152 }
153 case kTfLiteInt16: {
154 TF_LITE_SPLIT(int16_t);
155 break;
156 }
157 case kTfLiteInt32: {
158 TF_LITE_SPLIT(int32_t);
159 break;
160 }
161 default:
162 TF_LITE_KERNEL_LOG(context, "Type %s currently not supported.",
163 TfLiteTypeGetName(op_context.input->type));
164 return kTfLiteError;
165 }
166#undef TF_LITE_SPLIT
167
168 return kTfLiteOk;
169}
170
171} // namespace split
172
173TfLiteRegistration* Register_SPLIT() {
174 static TfLiteRegistration r = {nullptr, nullptr, split::Prepare, split::Eval};
175 return &r;
176}
177
178} // namespace builtin
179} // namespace ops
180} // namespace tflite
181